diff --git a/omniadvisor/pom.xml b/omniadvisor/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..c2c31b7c59ac0513dfe93fe421bb9afc1f7ad322 --- /dev/null +++ b/omniadvisor/pom.xml @@ -0,0 +1,989 @@ + + + 4.0.0 + + com.huawei.boostkit + omniadvisor-log-analyzer + 1.0.0 + + + Kunpeng BoostKit + + + + aarch64 + + 3.2.0 + 3.1.1 + 0.10.0 + + 12.16.0 + 2.2.5 + 7.5 + 7.6 + + 8.0.11 + + 4.1.1 + 1.3.4 + 1.19 + 1.2.83 + 2.1.6 + 1.9.2 + 2.10.0 + 2.10.5.1 + 2.10.0 + 1.1.1 + 2.4.7 + + 1.9.4 + + 2.0.7 + 1.7.30 + 2.20.0 + + 27.0-jre + 4.0 + 1.3.4 + 9.8.1 + 4.2.1 + 3.4.14 + + 1.1.3 + 3.10 + 3.4.1 + 3.6 + 2.6 + 1.10.0 + 2.8.0 + + 4.11 + 1.10.19 + + 2.12.15 + 2.12 + 2.0 + incremental + + 8 + 8 + 3.1.2 + 3.8.1 + + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.scala-lang + scala-compiler + ${scala.version} + + + com.jsuereth + scala-arm_${scala.compat.version} + ${scala.arm.version} + + + org.scala-lang + scala-library + + + + + + + org.apache.tez + tez-api + ${tez.version} + + + log4j + log4j + + + commons-io + commons-io + + + org.apache.hadoop + hadoop-annotations + + + org.apache.commons + commons-compress + + + com.google.guava + guava + + + com.google.inject + guice + + + com.google.inject.extensions + guice-servlet + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + javax.xml.bind + jaxb-api + + + org.codehaus.jettison + jettison + + + org.slf4j + slf4j-api + + + org.apache.hadoop + hadoop-auth + + + org.apache.hadoop + hadoop-common + + + org.apache.hadoop + hadoop-hdfs-client + + + org.apache.hadoop + hadoop-yarn-api + + + org.apache.hadoop + hadoop-yarn-common + + + org.apache.hadoop + hadoop-yarn-client + + + + + org.apache.tez + tez-common + ${tez.version} + + + * + * + + + + + org.apache.tez + tez-dag + ${tez.version} + + + * + * + + + + + + + org.apache.hadoop + hadoop-auth + ${hadoop.version} + + + log4j + log4j + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + org.slf4j + slf4j-reload4j + + + ch.qos.reload4j + reload4j + + + commons-io + commons-io + + + commons-logging + commons-logging + + + com.nimbusds + nimbus-jose-jwt + + + com.google.guava + guava + + + org.apache.zookeeper + zookeeper + + + commons-codec + commons-codec + + + net.minidev + json-smart + + + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + org.slf4j + slf4j-reload4j + + + ch.qos.reload4j + reload4j + + + org.apache.avro + * + + + org.apache.commons + commons-configuration2 + + + org.apache.commons + commons-lang3 + + + org.apache.commons + commons-math3 + + + org.apache.commons + commons-text + + + com.google.guava + guava + + + com.google.guava + guava + + + org.codehaus.woodstox + stax2-api + + + org.apache.zookeeper + zookeeper + + + commons-logging + commons-logging + + + com.fasterxml.jackson.core + jackson-databind + + + javax.activation + activation + + + javax.ws.rs + jsr311-api + + + + + org.apache.hadoop + hadoop-hdfs-client + ${hadoop.version} + runtime + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + log4j + log4j + + + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + ${hadoop.version} + + + * + * + + + + + org.apache.hadoop + hadoop-yarn-api + ${hadoop.version} + + + * + * + + + + + + org.apache.zookeeper + zookeeper + ${zookeeper.version} + runtime + + + org.slf4j + 
slf4j-api + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + com.google.code.findbugs + jsr305 + + + + + + + org.apache.spark + spark-core_${scala.compat.version} + ${spark.version} + + + com.typesage.akka + * + + + org.apache.avro + * + + + org.apache.hadoop + * + + + net.razorvine + * + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + org.slf4j + slf4j-reload4j + + + ch.qos.reload4j + reload4j + + + org.apache.commons + commons-text + + + org.apache.commons + commons-lang3 + + + commons-net + commons-net + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + jcl-over-slf4j + + + com.google.code.findbugs + jsr305 + + + org.slf4j + jul-to-slf4j + + + io.netty + netty-all + + + org.scala-lang + scala-library + + + org.scala-lang + scala-reflect + + + org.scala-lang + scala-compiler + + + org.scala-lang.modules + scala-xml_2.12 + + + org.apache.zookeeper + zookeeper + + + org.apache.curator + curator-recipes + + + + + org.apache.spark + spark-kvstore_${scala.compat.version} + ${spark.version} + + + + + io.ebean + ebean + ${ebean.version} + runtime + + + org.slf4j + slf4j-api + + + + + io.ebean + ebean-api + ${ebean.version} + + + org.slf4j + slf4j-api + + + + + io.ebean + ebean-querybean + ${ebean.version} + + + io.ebean + ebean-annotation + ${ebean-annotation.version} + + + io.ebean + ebean-ddl-generator + ${ebean.version} + runtime + + + io.ebean + ebean-migration + + + + + io.ebean + ebean-migration + ${ebean.version} + runtime + + + io.ebean + querybean-generator + ${ebean.version} + provided + + + org.codehaus.woodstox + stax2-api + ${stax2-api.version} + runtime + + + + + com.google.guava + guava + ${guava.version} + + + com.google.code.findbugs + jsr305 + + + + + + com.nimbusds + nimbus-jose-jwt + ${nimbus-jose-jwt.version} + + + + + com.sun.jersey + jersey-client + ${jersey-client.version} + + + org.codehaus.jettison + jettison + ${jettison.version} + + + commons-logging + commons-logging + ${commons-logging.version} + runtime + + + org.apache.commons + commons-text + ${commons-text.version} + + + org.apache.commons + commons-lang3 + + + + + commons-lang + commons-lang + ${commons-lang.version} + + + org.apache.commons + commons-lang3 + ${commons-lang3.version} + runtime + + + org.apache.commons + commons-configuration2 + ${commons-configuration2.version} + + + org.apache.commons + commons-lang3 + + + commons-logging + commons-logging + + + org.apache.commons + commons-text + + + + + + + mysql + mysql-connector-java + ${mysql.jdbc.version} + + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-server + + + org.slf4j + slf4j-log4j12 + + + com.google.protobuf + protobuf-java + + + + + + + com.alibaba + fastjson + ${fastjon.version} + + + + com.fasterxml.jackson.core + jackson-databind + 2.10.0 + + + org.codehaus.jackson + jackson-mapper-asl + ${jackson.version} + + + com.fasterxml.jackson.module + jackson-module-scala_${scala.compat.version} + ${jackson-module-scala.version} + + + org.scala-lang + scala-library + + + + + com.fasterxml.jackson.core + jackson-core + ${jackson-core.version} + + + + net.minidev + json-smart + ${json-smart.version} + + + jakarta.ws.rs + jakarta.ws.rs-api + ${jakarta.version} + + + io.ebean + persistence-api + ${ebean-persistence.version} + + + io.ebean + ebean-datasource-api + ${ebean-datasource.version} + + + org.codehaus.jackson + jackson-core-asl + ${jackson.version} + + + + org.slf4j + 
slf4j-api + ${slf4j.version} + + + org.slf4j + jul-to-slf4j + ${slf4j.version} + runtime + + + org.slf4j + jcl-over-slf4j + ${slf4j.version} + runtime + + + org.apache.logging.log4j + log4j-slf4j2-impl + ${log4j.version} + runtime + + + org.apache.logging.log4j + log4j-api + + + org.slf4j + slf4j-api + + + + + org.slf4j + slf4j-log4j12 + ${slf4j-log4j12.version} + runtime + + + log4j + log4j + + + org.slf4j + slf4j-api + + + + + org.apache.logging.log4j + log4j-api + ${log4j.version} + runtime + + + org.apache.logging.log4j + log4j-core + ${log4j.version} + runtime + + + + org.apache.logging.log4j + log4j-1.2-api + ${log4j.version} + runtime + + + + + junit + junit + ${junit.version} + test + + + org.mockito + mockito-all + ${mockito-all.version} + test + + + org.apache.hadoop + hadoop-minikdc + ${hadoop.version} + test + + + org.slf4j + slf4j-log4j12 + + + + + + + boostkit-${project.artifactId}-${project.version}-${dep.os.arch} + + + src/main/resources + + * + */* + + + + + + + org.apache.maven.plugins + maven-resources-plugin + 3.3.1 + + + copy-resources + validate + + copy-resources + + + ${project.build.directory}/resources + + + src/main/resources + true + + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.7.2 + + ${scala.recompile.mode} + + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven.compiler.plugin.version} + + ${maven.compiler.source} + ${maven.compiler.target} + + + + compile + + compile + + + + + + io.repaint.maven + tiles-maven-plugin + 2.24 + true + + + io.ebean.tile:enhancement:${ebean.version} + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.jar.plugin.version} + + + false + + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.10 + + + copy-dependencies + package + + copy-dependencies + + + runtime + ${project.build.directory}/lib + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.2.0 + + false + ${project.build.directory} + boostkit-${project.artifactId}-${project.version}-${dep.os.arch} + + src/main/assembly/assembly.xml + + + + + make-assembly + package + + single + + + + + + + \ No newline at end of file diff --git a/omniadvisor/src/main/assembly/assembly.xml b/omniadvisor/src/main/assembly/assembly.xml new file mode 100644 index 0000000000000000000000000000000000000000..d7c6ccb8df39e30417e8b92e1c7daf5fbec61213 --- /dev/null +++ b/omniadvisor/src/main/assembly/assembly.xml @@ -0,0 +1,23 @@ + + bin + + zip + + + + ${basedir}/target + + *.jar + + ./ + + + ${basedir}/target/resources/conf + ./conf + + + ${basedir}/target/lib + ./lib + + + \ No newline at end of file diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisor.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisor.java new file mode 100644 index 0000000000000000000000000000000000000000..315e0208aa561752566f01e5aef1976a5635b77a --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisor.java @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.omniadvisor;
+
+import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException;
+import com.huawei.boostkit.omniadvisor.executor.OmniAdvisorRunner;
+import org.apache.commons.lang.time.DateUtils;
+
+import java.text.ParseException;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public final class OmniAdvisor {
+    private static final int REQUIRED_PARAMS_LENGTH = 4;
+    private static final String[] TIME_PARSE_PATTERNS = {"yyyy-MM-dd HH:mm:ss"};
+
+    private OmniAdvisor() {}
+
+    public static void main(String[] args) {
+        List<String> params = removeEmptyArgs(args);
+
+        if (params.size() != REQUIRED_PARAMS_LENGTH) {
+            throw new OmniAdvisorException("Invalid number of parameters. Exactly four parameters are required.");
+        }
+
+        Date startDate;
+        Date finishDate;
+        try {
+            startDate = DateUtils.parseDate(params.get(0), TIME_PARSE_PATTERNS);
+            finishDate = DateUtils.parseDate(params.get(1), TIME_PARSE_PATTERNS);
+        } catch (ParseException e) {
+            throw new OmniAdvisorException("Unsupported date format. Only the 'yyyy-MM-dd HH:mm:ss' format is supported", e);
+        }
+
+        long startTimeMills = startDate.getTime();
+        long finishedTimeMills = finishDate.getTime();
+
+        if (startTimeMills > finishedTimeMills) {
+            throw new OmniAdvisorException("Start time cannot be later than finish time");
+        }
+
+        OmniAdvisorContext.initContext(params.get(2), params.get(3));
+        OmniAdvisorRunner runner = new OmniAdvisorRunner(startTimeMills, finishedTimeMills);
+        runner.run();
+    }
+
+    private static List<String> removeEmptyArgs(String[] args) {
+        return Arrays.stream(args).filter(arg -> !arg.isEmpty()).collect(Collectors.toList());
+    }
+}
diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisorContext.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisorContext.java
new file mode 100644
index 0000000000000000000000000000000000000000..a3e52441e4b809a7732ba806f1be107aade20c19
--- /dev/null
+++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/OmniAdvisorContext.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.huawei.boostkit.omniadvisor; + +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omniadvisor.configuration.DBConfigure; +import com.huawei.boostkit.omniadvisor.configuration.OmniAdvisorConfigure; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherFactory; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import io.ebean.Finder; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.apache.commons.configuration2.builder.FileBasedConfigurationBuilder; +import org.apache.commons.configuration2.builder.fluent.Configurations; +import org.apache.commons.configuration2.ex.ConfigurationException; +import org.apache.hadoop.conf.Configuration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; + +import static java.lang.String.format; + +public final class OmniAdvisorContext { + private static final Logger LOG = LoggerFactory.getLogger(OmniAdvisorContext.class); + private static final String CONFIG_FILE_NAME = "omniAdvisorLogAnalyzer.properties"; + private static final List DEFAULT_HADOOP_CONFIG_FILES = ImmutableList.of("hdfs-site.xml", "core-site.xml"); + private static final String ENCODING = StandardCharsets.UTF_8.displayName(Locale.ENGLISH); + private static final Configuration HADOOP_CONF; + + private static OmniAdvisorContext instance = null; + + private final OmniAdvisorConfigure omniAdvisorConfigure; + private final FetcherFactory fetcherFactory; + + static { + HADOOP_CONF = new Configuration(); + for (String configFileName : DEFAULT_HADOOP_CONFIG_FILES) { + URL configFile = Thread.currentThread().getContextClassLoader().getResource(configFileName); + if (configFile != null) { + LOG.info("Add resource {} to hadoop config", configFile); + HADOOP_CONF.addResource(configFile); + } + } + } + + private Finder finder = new Finder<>(AppResult.class); + + private OmniAdvisorContext() { + this(false, null, null); + } + + private OmniAdvisorContext(String user, String passwd) { + this(true, user, passwd); + } + + private OmniAdvisorContext(boolean initDatabase, String user, String passwd) { + PropertiesConfiguration configuration = loadConfigure(); + if (initDatabase) { + initDataSource(configuration, user, passwd); + } + this.omniAdvisorConfigure = loadOmniTuningConfig(configuration); + this.fetcherFactory = loadFetcherFactory(configuration); + } + + public static void initContext(String user, String passwd) { + if (instance == null) { + instance = new OmniAdvisorContext(user, passwd); + } else { + LOG.warn("OmniTuningContext has been instantiated"); + } + } + + // only use for unit test + public static void initContext() { + if (instance == null) { + instance = new OmniAdvisorContext(); + } else { + LOG.warn("OmniTuningContext has been instantiated"); + } + } + + public static OmniAdvisorContext getInstance() { + if (instance == null) { + throw new OmniAdvisorException("OmniTuningContext has not been instantiated"); + } + return instance; + } + + public static Configuration getHadoopConfig() { + return HADOOP_CONF; + } + + public OmniAdvisorConfigure getOmniAdvisorConfigure() { + return omniAdvisorConfigure; + } + + public FetcherFactory getFetcherFactory() { + return fetcherFactory; + } + + public Finder getFinder() { + return finder; + } + + public void setFinder(Finder finder) { + this.finder = finder; + } + + private 
PropertiesConfiguration loadConfigure() { + try { + Configurations configurations = new Configurations(); + URL configFileUrl = Thread.currentThread().getContextClassLoader().getResource(CONFIG_FILE_NAME); + if (configFileUrl == null) { + throw new OmniAdvisorException("Config file is missing"); + } + FileBasedConfigurationBuilder.setDefaultEncoding(OmniAdvisorConfigure.class, ENCODING); + return configurations.properties(configFileUrl); + } catch (ConfigurationException e) { + throw new OmniAdvisorException(format("Failed to read config file, %s", e)); + } + } + + private void initDataSource(PropertiesConfiguration configuration, String user, String passwd) { + DBConfigure.initDatabase(configuration, user, passwd); + } + + private OmniAdvisorConfigure loadOmniTuningConfig(PropertiesConfiguration configuration) { + return new OmniAdvisorConfigure(configuration); + } + + private FetcherFactory loadFetcherFactory(PropertiesConfiguration configuration) { + return new FetcherFactory(configuration); + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/analysis/AnalyticJob.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/analysis/AnalyticJob.java new file mode 100644 index 0000000000000000000000000000000000000000..95e58e0237aa5345ed5d41e2e469d7124f747581 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/analysis/AnalyticJob.java @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.analysis; + +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; + +public interface AnalyticJob { + String getApplicationId(); + + FetcherType getType(); +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/DBConfigure.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/DBConfigure.java new file mode 100644 index 0000000000000000000000000000000000000000..6164c896ca2af66d2294e31059f8109829ed487d --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/DBConfigure.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import io.ebean.DatabaseFactory; +import io.ebean.config.DatabaseConfig; +import io.ebean.datasource.DataSourceFactory; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Properties; + +import static java.lang.String.format; + +public final class DBConfigure { + private static final Logger LOG = LoggerFactory.getLogger(DBConfigure.class); + + private static final String DB_DEFAULT_DRIVER = "com.mysql.cj.jdbc.Driver"; + private static final String DB_DRIVER_KEY = "datasource.db.driver"; + private static final String DB_URL_KEY = "datasource.db.url"; + private static final String DB_USERNAME_KEY = "datasource.db.username"; + private static final String DB_PASSWORD_KEY = "datasource.db.password"; + + private DBConfigure() {} + + public static void initDatabase(PropertiesConfiguration configuration, String userName, String passWord) { + Properties databaseProperties = new Properties(); + databaseProperties.put(DB_DRIVER_KEY, configuration.getString(DB_DRIVER_KEY, DB_DEFAULT_DRIVER)); + databaseProperties.put(DB_URL_KEY, configuration.getString(DB_URL_KEY)); + databaseProperties.put(DB_USERNAME_KEY, userName); + databaseProperties.put(DB_PASSWORD_KEY, passWord); + + DatabaseConfig dbConfig = new DatabaseConfig(); + dbConfig.loadFromProperties(databaseProperties); + + dbConfig.setDataSource(DataSourceFactory.create(dbConfig.getName(), dbConfig.getDataSourceConfig())); + + checkInit(dbConfig); + + DatabaseFactory.create(dbConfig); + } + + public static void checkInit(DatabaseConfig dbConfig) { + boolean isInit; + try (Connection conn = dbConfig.getDataSource().getConnection(); + ResultSet rs = conn.getMetaData().getTables(conn.getCatalog(), null, AppResult.RESULT_TABLE_NAME, null)) { + isInit = rs.next(); + } catch (SQLException e) { + throw new OmniAdvisorException(format("Failed to connect to dataSource, %s", e)); + } + + if (!isInit) { + LOG.info("Analyze result table is not exist, creating it"); + dbConfig.setDdlGenerate(true); + dbConfig.setDdlRun(true); + } + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/OmniAdvisorConfigure.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/OmniAdvisorConfigure.java new file mode 100644 index 0000000000000000000000000000000000000000..cac34518aab3f23d1a2c1798aedfea720fcdee86 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/configuration/OmniAdvisorConfigure.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import org.apache.commons.configuration2.PropertiesConfiguration; + +public class OmniAdvisorConfigure { + private static final int DEFAULT_THREAD_COUNT = 3; + private static final String THREAD_COUNT_CONF_KEY = "log.analyzer.thread.count"; + private static final String KERBEROS_PRINCIPAL_KEY = "kerberos.principal"; + private static final String KERBEROS_KEYTAB_FILE_KEY = "kerberos.keytab.file"; + + private final int threadCount; + private String kerberosPrincipal; + private String kerberosKeytabFile; + + public OmniAdvisorConfigure(PropertiesConfiguration configuration) { + this.threadCount = configuration.getInt(THREAD_COUNT_CONF_KEY, DEFAULT_THREAD_COUNT); + this.kerberosPrincipal = configuration.getString(KERBEROS_PRINCIPAL_KEY, null); + this.kerberosKeytabFile = configuration.getString(KERBEROS_KEYTAB_FILE_KEY, null); + } + + public int getThreadCount() { + return threadCount; + } + + public String getKerberosPrincipal() { + return kerberosPrincipal; + } + + public String getKerberosKeytabFile() { + return kerberosKeytabFile; + } + + public void setKerberosPrincipal(String kerberosPrincipal) { + this.kerberosPrincipal = kerberosPrincipal; + } + + public void setKerberosKeytabFile(String kerberosKeytabFile) { + this.kerberosKeytabFile = kerberosKeytabFile; + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/exception/OmniAdvisorException.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/exception/OmniAdvisorException.java new file mode 100644 index 0000000000000000000000000000000000000000..fffa3c10a25df77a89a13b38a68750dbe8c081b7 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/exception/OmniAdvisorException.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.exception; + +public class OmniAdvisorException extends RuntimeException { + public OmniAdvisorException(String message) { + super(message); + } + + public OmniAdvisorException(Throwable throwable) { + super(throwable); + } + + public OmniAdvisorException(String message, Throwable throwable) { + super(message, throwable); + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/AnalysisAction.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/AnalysisAction.java new file mode 100644 index 0000000000000000000000000000000000000000..ff0876db89fad53011fdf6fbd6028fa89a3d1b98 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/AnalysisAction.java @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.omniadvisor.executor;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import com.huawei.boostkit.omniadvisor.OmniAdvisorContext;
+import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob;
+import com.huawei.boostkit.omniadvisor.configuration.OmniAdvisorConfigure;
+import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException;
+import com.huawei.boostkit.omniadvisor.fetcher.Fetcher;
+import com.huawei.boostkit.omniadvisor.fetcher.FetcherFactory;
+import com.huawei.boostkit.omniadvisor.security.HadoopSecurity;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.security.PrivilegedAction;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Timer;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+public class AnalysisAction implements PrivilegedAction<Void> {
+    private static final Logger LOG = LoggerFactory.getLogger(AnalysisAction.class);
+
+    // Polling interval of the thread-pool listener, in milliseconds
+    private static final int WAIT_INTERVAL = 1000;
+
+    private final HadoopSecurity hadoopSecurity;
+    private final long startTimeMills;
+    private final long finishTimeMills;
+
+    private final Object appsLock;
+
+    public AnalysisAction(HadoopSecurity hadoopSecurity, long startTimeMills, long finishTimeMills) {
+        this.appsLock = new Object();
+        this.hadoopSecurity = hadoopSecurity;
+        this.startTimeMills = startTimeMills;
+        this.finishTimeMills = finishTimeMills;
+    }
+
+    @Override
+    public Void run() {
+        OmniAdvisorContext context = OmniAdvisorContext.getInstance();
+
+        FetcherFactory fetcherFactory = context.getFetcherFactory();
+        OmniAdvisorConfigure omniAdvisorConfigure = context.getOmniAdvisorConfigure();
+
+        try {
+            hadoopSecurity.checkLogin();
+        } catch (IOException e) {
+            LOG.error("Error with Hadoop Kerberos login", e);
+            throw new OmniAdvisorException(e);
+        }
+
+        LOG.info("Fetching analytic job list");
+
+        List<AnalyticJob> analyticJobs = new ArrayList<>();
+        for (Fetcher fetcher : fetcherFactory.getAllFetchers()) {
+            LOG.info("Fetching jobs from {}", fetcher.getType().getName());
+            List<AnalyticJob> fetchedJobs = fetcher.fetchAnalyticJobs(startTimeMills, finishTimeMills);
+            LOG.info("Fetched {} jobs from {}", fetchedJobs.size(), fetcher.getType().getName());
+            analyticJobs.addAll(fetchedJobs);
+        }
+
+        LOG.info("Fetchers fetched {} jobs in total", analyticJobs.size());
+
+        if (!analyticJobs.isEmpty()) {
+            ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("omni-tuning-thread-%d").build();
+            int executorNum = Integer.min(analyticJobs.size(), omniAdvisorConfigure.getThreadCount());
+            int queueSize = Integer.max(analyticJobs.size(), executorNum);
+            ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(executorNum, executorNum, 0L,
+                TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(queueSize), factory);
+            for (AnalyticJob analyticJob : analyticJobs) {
+                synchronized (appsLock) {
+                    threadPoolExecutor.submit(new ExecutorJob(analyticJob, fetcherFactory, appsLock));
+                }
+            }
+            Timer timer = new
Timer(); + timer.schedule(new ThreadPoolListener(timer, threadPoolExecutor), WAIT_INTERVAL, WAIT_INTERVAL); + } + return null; + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ExecutorJob.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ExecutorJob.java new file mode 100644 index 0000000000000000000000000000000000000000..7d8203e5b55d634509cf35363f2f9c85ddbc23bb --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ExecutorJob.java @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.executor; + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.fetcher.Fetcher; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherFactory; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import io.ebean.DB; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Optional; + +import static com.huawei.boostkit.omniadvisor.utils.MathUtils.SECOND_IN_MS; + +class ExecutorJob implements Runnable { + private static final Logger LOG = LoggerFactory.getLogger(ExecutorJob.class); + + private final AnalyticJob analyticJob; + private final FetcherFactory fetcherFactory; + private final Object appsLock; + + public ExecutorJob(AnalyticJob analyticJob, FetcherFactory fetcherFactory, Object appsLock) { + this.analyticJob = analyticJob; + this.fetcherFactory = fetcherFactory; + this.appsLock = appsLock; + } + + @Override + public void run() { + FetcherType type = analyticJob.getType(); + String appId = analyticJob.getApplicationId(); + + LOG.info("Analyzing {} {}", type.getName(), appId); + + long analysisStartTime = System.currentTimeMillis(); + + Fetcher fetcher = fetcherFactory.getFetcher(type); + + final Optional result = fetcher.analysis(analyticJob); + if (result.isPresent()) { + synchronized (appsLock) { + AppResult analyzeResult = result.get(); + LOG.info("Analysis get result {}", appId); + try { + DB.execute(analyzeResult::save); + } catch (Throwable e) { + LOG.error("Error in saving analyze result, {}", e.getMessage()); + } + } + } else { + LOG.info("Analysis get empty result {}", appId); + } + + long analysisTimeMills = System.currentTimeMillis() - analysisStartTime; + + LOG.info("Finish analysis {} {} using {}s", type, appId, analysisTimeMills / SECOND_IN_MS); + } +} \ No newline at end of file diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/OmniAdvisorRunner.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/OmniAdvisorRunner.java new file mode 100644 index 0000000000000000000000000000000000000000..680d4cf1ba234acca09939508b6c8a72b80dbb51 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/OmniAdvisorRunner.java @@ -0,0 +1,50 @@ +/* + * 
Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.executor; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.security.HadoopSecurity; +import org.apache.hadoop.conf.Configuration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +public class OmniAdvisorRunner implements Runnable { + private static final Logger LOG = LoggerFactory.getLogger(OmniAdvisorRunner.class); + + private final long startTimeMills; + private final long finishTimeMills; + + public OmniAdvisorRunner(long startTimeMills, long finishTimeMills) { + this.startTimeMills = startTimeMills; + this.finishTimeMills = finishTimeMills; + } + + @Override + public void run() { + LOG.info("OmniAdvisor has started"); + try { + Configuration hadoopConf = OmniAdvisorContext.getHadoopConfig(); + HadoopSecurity hadoopSecurity = new HadoopSecurity(hadoopConf); + hadoopSecurity.doAs(new AnalysisAction(hadoopSecurity, startTimeMills, finishTimeMills)); + } catch (IOException e) { + LOG.error("failed to analyze jobs", e); + throw new OmniAdvisorException(e); + } + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ThreadPoolListener.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ThreadPoolListener.java new file mode 100644 index 0000000000000000000000000000000000000000..2ec2bf2a9321c69d0eb4d60cb91dfa77d722de19 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/executor/ThreadPoolListener.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.executor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.ThreadPoolExecutor; + +public class ThreadPoolListener extends TimerTask { + private static final Logger LOG = LoggerFactory.getLogger(ThreadPoolListener.class); + + private final Timer timer; + private final ThreadPoolExecutor executor; + + public ThreadPoolListener(Timer timer, ThreadPoolExecutor executor) { + this.timer = timer; + this.executor = executor; + } + + @Override + public void run() { + LOG.info("Executor taskCount {}, active count {}, complete count {}, {} left", + executor.getTaskCount(), executor.getActiveCount(), executor.getCompletedTaskCount(), + executor.getTaskCount() - executor.getCompletedTaskCount()); + if (executor.getActiveCount() == 0) { + executor.shutdown(); + timer.cancel(); + } + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/Fetcher.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/Fetcher.java new file mode 100644 index 0000000000000000000000000000000000000000..89dfd89b87339ebc3fec1d56176aa7b8c3312a29 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/Fetcher.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.huawei.boostkit.omniadvisor.fetcher; + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.models.AppResult; + +import java.util.List; +import java.util.Optional; + +public interface Fetcher { + boolean isEnable(); + + FetcherType getType(); + + List fetchAnalyticJobs(long startTimeMills, long finishedTimeMills); + + Optional analysis(AnalyticJob job); +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherFactory.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..9358b6fd15c355dbb6dcd104976e354b7303a4b4 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.fetcher; + +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.spark.SparkFetcher; +import com.huawei.boostkit.omniadvisor.tez.TezFetcher; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.lang.String.format; + +public class FetcherFactory { + private static final Logger LOG = LoggerFactory.getLogger(FetcherFactory.class); + + private final Map enabledFetchers; + + public FetcherFactory(PropertiesConfiguration configuration) { + Map fetchers = new HashMap<>(); + + // init TEZ fetcher + Fetcher tezFetcher = new TezFetcher(configuration); + if (tezFetcher.isEnable()) { + LOG.info("TEZ Fetcher is enabled."); + fetchers.put(FetcherType.TEZ, tezFetcher); + } + + // init SPARK fetcher + Fetcher sparkFetcher = new SparkFetcher(configuration); + if (sparkFetcher.isEnable()) { + LOG.info("Spark Fetcher is enabled."); + fetchers.put(FetcherType.SPARK, sparkFetcher); + } + + this.enabledFetchers = fetchers; + } + + public Fetcher getFetcher(FetcherType type) { + if (enabledFetchers.containsKey(type)) { + return enabledFetchers.get(type); + } else { + throw new OmniAdvisorException(format("Fetcher [%s] is disabled", type.getName())); + } + } + + public List getAllFetchers() { + return ImmutableList.copyOf(enabledFetchers.values()); + } + + public void addFetcher(FetcherType type, Fetcher fetcher) { + enabledFetchers.put(type, fetcher); + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherType.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherType.java new file mode 100644 index 0000000000000000000000000000000000000000..f4a78e562b153853a2ba7dcd847dd7d31e246c4a --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/fetcher/FetcherType.java @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.fetcher; + +public enum FetcherType { + SPARK("SPARK"), TEZ("TEZ"); + + private final String name; + + FetcherType(String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/models/AppResult.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/models/AppResult.java new file mode 100644 index 0000000000000000000000000000000000000000..7c91ff2897deac8292fddd72ef677391ed0ca80e --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/models/AppResult.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.models; + +import com.huawei.boostkit.omniadvisor.utils.MathUtils; +import io.ebean.Model; +import io.ebean.annotation.Index; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.Id; +import javax.persistence.Table; + +@Entity +@Table(name = AppResult.RESULT_TABLE_NAME) +@Index(name = "yarn_app_result_i1", columnNames = {"application_id"}) +@Index(name = "yarn_app_result_i2", columnNames = {"application_name"}) +public class AppResult extends Model { + private static final long serialVersionUID = 1L; + + public static final long FAILED_JOB_DURATION = MathUtils.DAY_IN_MS; + public static final String RESULT_TABLE_NAME = "yarn_app_result"; + public static final int FAILED_STATUS = 0; + public static final int SUCCEEDED_STATUS = 1; + private static final int APPLICATION_ID_LIMIT = 50; + private static final int APPLICATION_NAME_LIMIT = 100; + private static final int APPLICATION_WORKLOAD_LIMIT = 50; + private static final int JOB_TYPE_LIMIT = 50; + + @Id + @Column(length = APPLICATION_ID_LIMIT, unique = true, nullable = false) + public String applicationId; + + @Column(length = APPLICATION_NAME_LIMIT, nullable = false) + public String applicationName; + + @Column(length = APPLICATION_WORKLOAD_LIMIT) + public String applicationWorkload; + + @Column() + public long startTime; + + @Column() + public long finishTime; + + @Column() + public long durationTime; + + @Column(length = JOB_TYPE_LIMIT) + public String jobType; + + @Column(columnDefinition = "TEXT CHARACTER SET utf8mb4") + public String parameters; + + @Column() + public int executionStatus; + + @Column(columnDefinition = "TEXT CHARACTER SET utf8mb4") + public String query; +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/security/HadoopSecurity.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/security/HadoopSecurity.java new file mode 100644 index 0000000000000000000000000000000000000000..6a6b45ed375809e9d5177b2459ace359ac4c9e74 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/security/HadoopSecurity.java @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.huawei.boostkit.omniadvisor.security;
+
+import com.huawei.boostkit.omniadvisor.OmniAdvisorContext;
+import com.huawei.boostkit.omniadvisor.configuration.OmniAdvisorConfigure;
+import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.security.PrivilegedAction;
+
+public final class HadoopSecurity {
+    private static final Logger LOG = LoggerFactory.getLogger(HadoopSecurity.class);
+
+    private String keytabFile;
+    private String principal;
+    private UserGroupInformation loginUser;
+
+    public HadoopSecurity(Configuration hadoopConf) throws IOException {
+        OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure();
+        UserGroupInformation.setConfiguration(hadoopConf);
+        boolean securityEnabled = UserGroupInformation.isSecurityEnabled();
+        if (securityEnabled) {
+            LOG.info("This cluster is Kerberos enabled.");
+            boolean login = true;
+
+            principal = configure.getKerberosPrincipal();
+            if (principal == null) {
+                LOG.error("Kerberos principal not set. Please set kerberos.principal in the configuration file");
+                login = false;
+            }
+
+            keytabFile = configure.getKerberosKeytabFile();
+            if (keytabFile == null) {
+                LOG.error("Keytab file not set. Please set kerberos.keytab.file in the configuration file");
+                login = false;
+            }
+
+            if (keytabFile != null && !new File(keytabFile).exists()) {
+                LOG.error("The keytab file at location [" + keytabFile + "] does not exist.");
+                login = false;
+            }
+
+            if (!login) {
+                throw new OmniAdvisorException("Cannot log in. This cluster is security enabled.");
+            }
+        }
+
+        this.loginUser = getLoginUser();
+    }
+
+    public UserGroupInformation getUGI() throws IOException {
+        checkLogin();
+        return loginUser;
+    }
+
+    public UserGroupInformation getLoginUser() throws IOException {
+        LOG.info("No login user. Creating login user");
+        LOG.info("Logging in with " + principal + " and " + keytabFile);
+        UserGroupInformation.loginUserFromKeytab(principal, keytabFile);
+        UserGroupInformation user = UserGroupInformation.getLoginUser();
+        LOG.info("Logged in with user " + user);
+        if (UserGroupInformation.isLoginKeytabBased()) {
+            LOG.info("Login is keytab based");
+        } else {
+            LOG.info("Login is not keytab based");
+        }
+        return user;
+    }
+
+    public void checkLogin() throws IOException {
+        if (loginUser == null) {
+            loginUser = getLoginUser();
+        } else {
+            loginUser.checkTGTAndReloginFromKeytab();
+        }
+    }
+
+    public void doAs(PrivilegedAction<Void> action) throws IOException {
+        UserGroupInformation ugi = getUGI();
+        if (ugi != null) {
+            ugi.doAs(action);
+        }
+    }
+}
diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/TezFetcher.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/TezFetcher.java
new file mode 100644
index 0000000000000000000000000000000000000000..0f165dba51672121c1755d6f1d4ef30757fae83b
--- /dev/null
+++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/TezFetcher.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez; + +import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.fetcher.Fetcher; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import com.huawei.boostkit.omniadvisor.tez.data.TezAnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.data.TezDagIdData; +import com.huawei.boostkit.omniadvisor.tez.utils.TezJsonUtils; +import com.huawei.boostkit.omniadvisor.tez.utils.TezUrlFactory; +import com.huawei.boostkit.omniadvisor.tez.utils.TimelineClient; +import com.huawei.boostkit.omniadvisor.utils.Utils; +import com.sun.jersey.api.client.ClientHandlerException; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.apache.hadoop.security.authentication.client.AuthenticationException; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.huawei.boostkit.omniadvisor.utils.Utils.loadParamsFromConf; + +public class TezFetcher implements Fetcher { + private static final Logger LOG = LoggerFactory.getLogger(TezFetcher.class); + + private static final String TEZ_ENABLE_KEY = "tez.enable"; + private static final String TEZ_WORKLOAD_KEY = "tez.workload"; + private static final String TEZ_TIMELINE_URL_KEY = "tez.timeline.url"; + private static final String TEZ_TIMELINE_TIMEOUT_KEY = "tez.timeline.timeout.ms"; + private static final String DEFAULT_WORKLOAD = "default"; + private static final String DEFAULT_TIMELINE_URL = "http://localhost:8188"; + private static final String HTTPS_PREFIX = "https://"; + private static final int DEFAULT_CONNECTION_TIMEOUT_MS = 6000; + private static final String TEZ_PARAMS_CONF_FILE = "TezParams"; + + private final boolean enable; + private String workload; + private TezJsonUtils tezJsonUtils; + + public TezFetcher(PropertiesConfiguration configuration) { + this.enable = configuration.getBoolean(TEZ_ENABLE_KEY, false); + if (enable) { + String timelineUrl = configuration.getString(TEZ_TIMELINE_URL_KEY, DEFAULT_TIMELINE_URL); + TezUrlFactory tezUrlFactory = new TezUrlFactory(timelineUrl); + this.workload = configuration.getString(TEZ_WORKLOAD_KEY, DEFAULT_WORKLOAD); + int timeout = configuration.getInt(TEZ_TIMELINE_TIMEOUT_KEY, DEFAULT_CONNECTION_TIMEOUT_MS); + boolean useHttps = timelineUrl.startsWith(HTTPS_PREFIX); + this.tezJsonUtils = new TezJsonUtils(tezUrlFactory, useHttps, timeout); + } + } + + @Override + public boolean isEnable() { + if (enable) { + try { + tezJsonUtils.verifyTimeLineServer(); + return true; + } catch (IOException e) { + LOG.error("Connect to timeline server failed {}, TEZ fetcher is disabled", e.getMessage()); + return false; + } + } + return false; + } + + @Override + public FetcherType 
getType() { + return FetcherType.TEZ; + } + + @Override + public List fetchAnalyticJobs(long startTimeMills, long finishedTimeMills) { + try { + return tezJsonUtils.getApplicationJobs(startTimeMills, finishedTimeMills); + } catch (IOException | AuthenticationException | ClientHandlerException e) { + LOG.error("Fetch applications from timeline server failed.", e); + return Collections.emptyList(); + } + } + + @Override + public Optional analysis(AnalyticJob job) { + if (!(job instanceof TezAnalyticJob)) { + throw new OmniAdvisorException("TezFetcher only support TezAnalyticJob"); + } + TezAnalyticJob tezJob = (TezAnalyticJob) job; + + List dagIds; + try { + dagIds = tezJsonUtils.getDAGIds(job.getApplicationId()); + } catch (IOException | ClientHandlerException e) { + LOG.error("Get dagIds from timeline server failed.", e); + return Optional.empty(); + } + + if (dagIds.isEmpty()) { + LOG.info("There is no dag in application {}, skip it", job.getApplicationId()); + return Optional.empty(); + } + + // If there is more than one dag in application, only analyze the last one + TezDagIdData tezDagId = dagIds.stream().max(TezDagIdData::compareTo).get(); + + return extractAppResult(tezJob, tezDagId); + } + + private Optional extractAppResult(TezAnalyticJob tezJob, TezDagIdData dagIdData) { + LOG.info("Analyzing dag {}", dagIdData.getDagId()); + AppResult appResult = new AppResult(); + Map jobConf; + try { + jobConf = tezJsonUtils.getConfigure(tezJob.getApplicationId()); + appResult.parameters = Utils.parseMapToJsonString(loadParamsFromConf(TEZ_PARAMS_CONF_FILE, jobConf)); + appResult.query = tezJsonUtils.getQueryString(dagIdData.getDagId()); + } catch (IOException e) { + LOG.error("Analyze job failed. ", e); + return Optional.empty(); + } + + appResult.applicationId = tezJob.getApplicationId(); + appResult.applicationName = tezJob.getApplicationName(); + appResult.applicationWorkload = workload; + appResult.jobType = tezJob.getType().getName(); + + if (dagIdData.isComplete()) { + appResult.startTime = dagIdData.getStartTime(); + appResult.finishTime = dagIdData.getEndTime(); + appResult.executionStatus = dagIdData.isSuccess() ? AppResult.SUCCEEDED_STATUS : AppResult.FAILED_STATUS; + appResult.durationTime = dagIdData.isSuccess() ? 
dagIdData.getDuration() : AppResult.FAILED_JOB_DURATION; + } else { + appResult.startTime = tezJob.getStartTimeMills(); + appResult.finishTime = tezJob.getFinishTimeMills(); + if (tezJob.getState() == YarnApplicationState.KILLED) { + LOG.info("Application {} is killed, regarded as a failed task", tezJob.getApplicationId()); + appResult.executionStatus = AppResult.FAILED_STATUS; + appResult.durationTime = AppResult.FAILED_JOB_DURATION; + } else { + LOG.info("Application {} using input time", tezJob.getApplicationId()); + appResult.executionStatus = AppResult.SUCCEEDED_STATUS; + appResult.durationTime = appResult.finishTime - appResult.startTime; + } + } + + return Optional.of(appResult); + } + + @VisibleForTesting + protected void setTezJsonUtils(TezJsonUtils jsonUtils) { + this.tezJsonUtils = jsonUtils; + } + + @VisibleForTesting + protected void setTimelineClient(TimelineClient timelineClient) { + this.tezJsonUtils.setTimelineClient(timelineClient); + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezAnalyticJob.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezAnalyticJob.java new file mode 100644 index 0000000000000000000000000000000000000000..6fcfa9c705d16237d42ff9f0a283913d3146f04e --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezAnalyticJob.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.tez.data; + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; + +public class TezAnalyticJob implements AnalyticJob { + private final String applicationId; + private final String applicationName; + private final long startTimeMills; + private final long finishTimeMills; + private final YarnApplicationState state; + + public TezAnalyticJob(String appId, String appName, long startTime, long finishTime, YarnApplicationState state) { + this.applicationId = appId; + this.applicationName = appName; + this.startTimeMills = startTime; + this.finishTimeMills = finishTime; + this.state = state; + } + + @Override + public String getApplicationId() { + return applicationId; + } + + @Override + public FetcherType getType() { + return FetcherType.TEZ; + } + + public String getApplicationName() { + return applicationName; + } + + public long getStartTimeMills() { + return startTimeMills; + } + + public long getFinishTimeMills() { + return finishTimeMills; + } + + public YarnApplicationState getState() { + return state; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (!(other instanceof TezAnalyticJob)) { + return false; + } + + TezAnalyticJob otherJob = (TezAnalyticJob) other; + return this.applicationId.equals(otherJob.applicationId) + && this.applicationName.equals(otherJob.applicationName) + && this.startTimeMills == otherJob.startTimeMills + && this.finishTimeMills == otherJob.finishTimeMills; + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezDagIdData.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezDagIdData.java new file mode 100644 index 0000000000000000000000000000000000000000..a47a0f36c97e522cea7e30613233bc4d286bd393 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/data/TezDagIdData.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.huawei.boostkit.omniadvisor.tez.data;
+
+import org.apache.tez.dag.app.dag.DAGState;
+
+public class TezDagIdData implements Comparable<TezDagIdData> {
+    private final String dagId;
+    private final long startTime;
+    private final long endTime;
+    private final long duration;
+    private final DAGState status;
+
+    public TezDagIdData(String dagId, long startTime, long endTime, long duration, DAGState status) {
+        this.dagId = dagId;
+        this.startTime = startTime;
+        this.endTime = endTime;
+        this.duration = duration;
+        this.status = status;
+    }
+
+    public String getDagId() {
+        return dagId;
+    }
+
+    public long getStartTime() {
+        return startTime;
+    }
+
+    public long getEndTime() {
+        return endTime;
+    }
+
+    public long getDuration() {
+        return duration;
+    }
+
+    public boolean isComplete() {
+        return (status == DAGState.SUCCEEDED ||
+            status == DAGState.FAILED ||
+            status == DAGState.KILLED ||
+            status == DAGState.ERROR ||
+            status == DAGState.TERMINATING);
+    }
+
+    public boolean isSuccess() {
+        return status == DAGState.SUCCEEDED;
+    }
+
+    @Override
+    public int compareTo(TezDagIdData other) {
+        return Long.compare(this.startTime, other.startTime);
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+        if (!(other instanceof TezDagIdData)) {
+            return false;
+        }
+        return this.dagId.equals(((TezDagIdData) other).dagId);
+    }
+}
diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezJsonUtils.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezJsonUtils.java
new file mode 100644
index 0000000000000000000000000000000000000000..43c02a8b7e96d6a4523a0d4e2f867dded63430f7
--- /dev/null
+++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezJsonUtils.java
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.data.TezAnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.data.TezDagIdData; +import org.apache.hadoop.security.authentication.client.AuthenticationException; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.apache.hadoop.yarn.server.resourcemanager.webapp.RMWSConsts; +import org.apache.tez.common.ATSConstants; +import org.apache.tez.dag.app.dag.DAGState; +import org.apache.tez.dag.history.utils.DAGUtils; +import org.codehaus.jackson.JsonNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLConnection; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class TezJsonUtils { + private static final Logger LOG = LoggerFactory.getLogger(TezJsonUtils.class); + private static final String HIVE_APP_NAME_PREFIX = "HIVE-"; + + private final TezUrlFactory tezUrlFactory; + private TimelineClient timelineClient; + + public TezJsonUtils(TezUrlFactory tezUrlFactory, boolean useHttps, int timeout) { + this.tezUrlFactory = tezUrlFactory; + this.timelineClient = new TimelineClient(OmniAdvisorContext.getHadoopConfig(), useHttps, timeout); + } + + public void verifyTimeLineServer() throws IOException { + URL timeLineUrl = tezUrlFactory.getRootURL(); + URLConnection connection = timeLineUrl.openConnection(); + connection.connect(); + } + + public List getApplicationJobs(long startedTime, long finishedTime) + throws IOException, AuthenticationException { + URL historyUrl = tezUrlFactory.getApplicationHistoryURL(startedTime, finishedTime); + LOG.info("calling REST API at {} to get applications", historyUrl.toString()); + JsonNode rootNode = timelineClient.readJsonNode(historyUrl); + JsonNode apps = rootNode.path("app"); + List analyticJobs = new ArrayList<>(); + for (JsonNode app : apps) { + String appId = app.get(RMWSConsts.APP_ID).getTextValue(); + if (OmniAdvisorContext.getInstance().getFinder().byId(appId) == null) { + String name = getApplicationName(app.get("name").getTextValue()); + String state = app.get("appState").getTextValue(); + TezAnalyticJob tezJob = + new TezAnalyticJob(appId, name, startedTime, finishedTime, YarnApplicationState.valueOf(state)); + analyticJobs.add(tezJob); + } + } + return analyticJobs; + } + + private String getApplicationName(String name) { + if (name.startsWith(HIVE_APP_NAME_PREFIX)) { + return name.substring(HIVE_APP_NAME_PREFIX.length()); + } else { + return name; + } + } + + public List getDAGIds(String applicationId) throws MalformedURLException { + URL dagIdUrl = tezUrlFactory.getDagIdURL(applicationId); + LOG.info("Get DAG ids from REST API at {}", dagIdUrl.toString()); + JsonNode rootNode = timelineClient.readJsonNode(dagIdUrl); + List dagIds = new ArrayList<>(); + + for (JsonNode entity : rootNode.get(ATSConstants.ENTITIES)) { + String dagId = entity.get(ATSConstants.ENTITY).getTextValue(); + long startTime = entity.get(ATSConstants.OTHER_INFO).path(ATSConstants.START_TIME).getLongValue(); + long endTime = entity.get(ATSConstants.OTHER_INFO).path(ATSConstants.FINISH_TIME).getLongValue(); + long duration = entity.get(ATSConstants.OTHER_INFO).path(ATSConstants.TIME_TAKEN).getLongValue(); + DAGState 
status =
+                DAGState.valueOf(entity.path(ATSConstants.OTHER_INFO).path(ATSConstants.STATUS).getTextValue());
+            dagIds.add(new TezDagIdData(dagId, startTime, endTime, duration, status));
+        }
+        LOG.info("Get {} dags for application {}", dagIds.size(), applicationId);
+        return dagIds;
+    }
+
+    public Map<String, String> getConfigure(String applicationId) throws MalformedURLException {
+        URL applicationURL = tezUrlFactory.getApplicationURL(applicationId);
+        LOG.info("Get configuration by calling REST API {}", applicationURL);
+        JsonNode rootNode = timelineClient.readJsonNode(applicationURL);
+        JsonNode config = rootNode.path(ATSConstants.OTHER_INFO).path(ATSConstants.CONFIG);
+        Iterator<String> fieldNames = config.getFieldNames();
+        Map<String, String> params = new HashMap<>();
+        while (fieldNames.hasNext()) {
+            String key = fieldNames.next();
+            String value = config.get(key).getTextValue();
+            params.put(key, value);
+        }
+        return params;
+    }
+
+    public String getQueryString(String dagId) throws MalformedURLException {
+        URL dagExtraInfoURL = tezUrlFactory.getDagExtraInfoURL(dagId);
+        LOG.info("Get query string by calling REST API {}", dagExtraInfoURL);
+        JsonNode rootNode = timelineClient.readJsonNode(dagExtraInfoURL);
+        return rootNode.path(ATSConstants.OTHER_INFO)
+            .path(ATSConstants.DAG_PLAN)
+            .path(DAGUtils.DAG_CONTEXT_KEY)
+            .get(ATSConstants.DESCRIPTION)
+            .getTextValue();
+    }
+
+    public void setTimelineClient(TimelineClient timelineClient) {
+        this.timelineClient = timelineClient;
+    }
+}
diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezUrlFactory.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezUrlFactory.java
new file mode 100644
index 0000000000000000000000000000000000000000..7980c5720e21ccb9ad5a2ee5a3f9753a29863f76
--- /dev/null
+++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TezUrlFactory.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import org.apache.hadoop.yarn.server.resourcemanager.webapp.RMWSConsts; +import org.apache.tez.common.ATSConstants; +import org.apache.tez.dag.history.logging.EntityTypes; + +import java.net.MalformedURLException; +import java.net.URL; + +import static java.lang.String.format; + +public class TezUrlFactory { + private static final String APPLICATION_TYPE = "TEZ"; + private static final String TEZ_APPLICATION_PREFIX = "tez_"; + + private static final String APPLICATION_HISTORY_URL = "%s/ws/v1/applicationhistory/apps?%s=%s&%s=%s&%s=%s"; + private static final String TIMELINE_BASE_URL = "%s" + ATSConstants.RESOURCE_URI_BASE; + private static final String TIMELINE_ENTITY_URL = TIMELINE_BASE_URL + "/%s/%s"; + private static final String TIMELINE_ENTITY_WITH_FILTER_URL = TIMELINE_BASE_URL + "/%s?primaryFilter=%s:%s"; + + private final String baseUrl; + + public TezUrlFactory(String baseUrl) { + this.baseUrl = baseUrl; + } + + public URL getRootURL() throws MalformedURLException { + return new URL(format(TIMELINE_BASE_URL, baseUrl)); + } + + public URL getApplicationURL(String applicationId) throws MalformedURLException { + return new URL(format(TIMELINE_ENTITY_URL, baseUrl, EntityTypes.TEZ_APPLICATION, + TEZ_APPLICATION_PREFIX + applicationId)); + } + + public URL getDagIdURL(String applicationId) throws MalformedURLException { + return new URL(format(TIMELINE_ENTITY_WITH_FILTER_URL, baseUrl, EntityTypes.TEZ_DAG_ID, + ATSConstants.APPLICATION_ID, applicationId)); + } + + public URL getDagExtraInfoURL(String dagId) throws MalformedURLException { + return new URL(format(TIMELINE_ENTITY_URL, baseUrl, EntityTypes.TEZ_DAG_EXTRA_INFO, dagId)); + } + + public URL getApplicationHistoryURL(long startTime, long finishTime) throws MalformedURLException { + return new URL(format(APPLICATION_HISTORY_URL, baseUrl, RMWSConsts.APPLICATION_TYPES, APPLICATION_TYPE, + RMWSConsts.STARTED_TIME_BEGIN, startTime, RMWSConsts.STARTED_TIME_END, finishTime)); + + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TimelineClient.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TimelineClient.java new file mode 100644 index 0000000000000000000000000000000000000000..61c26824f483b7617fd1b91c85561d8f5a09a496 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/tez/utils/TimelineClient.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.sun.jersey.api.client.Client; +import com.sun.jersey.api.client.ClientResponse; +import com.sun.jersey.api.client.WebResource; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.dag.api.client.TimelineReaderFactory; +import org.codehaus.jackson.JsonNode; +import org.codehaus.jackson.map.ObjectMapper; +import org.codehaus.jettison.json.JSONObject; + +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; + +import java.io.IOException; +import java.net.URL; + +import static java.lang.String.format; + +public class TimelineClient implements AutoCloseable { + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private Client httpClient; + + public TimelineClient(Configuration conf, boolean useHttps, int connTimeout) { + try { + this.httpClient = TimelineReaderFactory.getTimelineReaderStrategy(conf, useHttps, connTimeout).getHttpClient(); + } catch (TezException | IOException e) { + throw new OmniAdvisorException(e); + } + } + + public JsonNode readJsonNode(URL url) { + WebResource resource = httpClient.resource(url.toString()); + ClientResponse response = resource.accept(MediaType.APPLICATION_JSON_TYPE) + .type(MediaType.APPLICATION_JSON_TYPE) + .get(ClientResponse.class); + + if (response.getStatus() == Response.Status.OK.getStatusCode()) { + try { + return MAPPER.readTree(response.getEntity(JSONObject.class).toString()); + } catch (IOException e) { + throw new OmniAdvisorException(e); + } + } else { + throw new OmniAdvisorException(format("Failed to get data from %s", url)); + } + } + + @VisibleForTesting + protected void setClient(Client client) { + this.httpClient = client; + } + + @Override + public void close() { + httpClient.destroy(); + } +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/MathUtils.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/MathUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..f1be9b790db2decfc4174270a15f199cd19b1ab1 --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/MathUtils.java @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.utils; + +public final class MathUtils { + public static final long SECOND_IN_MS = 1000L; + public static final long MINUTE_IN_MS = 60L * SECOND_IN_MS; + public static final long HOUR_IN_MS = 60L * MINUTE_IN_MS; + public static final long DAY_IN_MS = 24 * HOUR_IN_MS; + + private MathUtils() {} +} diff --git a/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/Utils.java b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/Utils.java new file mode 100644 index 0000000000000000000000000000000000000000..5bcdb2b089d444029f88416f7f542308a0cc837f --- /dev/null +++ b/omniadvisor/src/main/java/com/huawei/boostkit/omniadvisor/utils/Utils.java @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.utils; + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import net.minidev.json.JSONObject; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class Utils { + private Utils() {} + + public static Map loadParamsFromConf(String configParamName, Map conf) { + URL fileURL = requireNonNull(Thread.currentThread().getContextClassLoader().getResource(configParamName), + format("Tez param config file %s is not found", configParamName)); + Map params = new HashMap<>(); + try (BufferedReader br = new BufferedReader( + new InputStreamReader(new FileInputStream(fileURL.getPath()), StandardCharsets.UTF_8))) { + String line; + while ((line = br.readLine()) != null) { + params.put(line, conf.getOrDefault(line, "")); + } + } catch (IOException e) { + throw new OmniAdvisorException(e); + } + return params; + } + + public static String parseMapToJsonString(Map map) { + JSONObject json = new JSONObject(); + json.putAll(map); + return json.toJSONString(); + } +} diff --git a/omniadvisor/src/main/resources/conf/SparkParams b/omniadvisor/src/main/resources/conf/SparkParams new file mode 100644 index 0000000000000000000000000000000000000000..6fffe44a0073b39e7ef70308aefe7f4be907cc83 --- /dev/null +++ b/omniadvisor/src/main/resources/conf/SparkParams @@ -0,0 +1,15 @@ +spark.executor.memory +spark.executor.cores +spark.executor.instances +spark.driver.cores +spark.driver.memory +spark.memory.offHeap.size +spark.broadcast.blockSize +spark.sql.shuffle.partitions +spark.executor.memoryOverhead +spark.memory.fraction +spark.memory.storageFraction +spark.sql.autoBroadcastJoinThreshold +spark.sql.join.preferSortMergeJoin +spark.sql.adaptive.enabled +spark.sql.adaptive.skewJoin.enabled \ No newline at end of file diff --git a/omniadvisor/src/main/resources/conf/TezParams 
b/omniadvisor/src/main/resources/conf/TezParams new file mode 100644 index 0000000000000000000000000000000000000000..e54c8dc2d436f239c4349c32a59f4afedf38cc29 --- /dev/null +++ b/omniadvisor/src/main/resources/conf/TezParams @@ -0,0 +1,14 @@ +hive.exec.reducers.max +hive.tez.container.size +hive.exec.parallel.thread.number +hive.cbo.enable +hive.exec.parallel +hive.tez.auto.reducer.parallelism +tez.runtime.io.sort.mb +tez.am.resource.memory.mb +tez.am.resource.cpu.vcores +tez.task.resource.memory.mb +tez.task.resource.cpu.vcores +tez.runtime.sort.spill.percent +tez.runtime.compress +tez.am.speculation.enabled \ No newline at end of file diff --git a/omniadvisor/src/main/resources/conf/log4j.properties b/omniadvisor/src/main/resources/conf/log4j.properties new file mode 100644 index 0000000000000000000000000000000000000000..f9ab97535daaf1053250c2b8296b18e116c84cee --- /dev/null +++ b/omniadvisor/src/main/resources/conf/log4j.properties @@ -0,0 +1,5 @@ +log4j.rootLogger=INFO, console +# console log +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yyyy-MM-dd HH\:mm\:ss} %p [%c] %m%n \ No newline at end of file diff --git a/omniadvisor/src/main/resources/conf/omniAdvisorLogAnalyzer.properties b/omniadvisor/src/main/resources/conf/omniAdvisorLogAnalyzer.properties new file mode 100644 index 0000000000000000000000000000000000000000..d39eba9ff9316d496e462b0468f8acbfe7edea80 --- /dev/null +++ b/omniadvisor/src/main/resources/conf/omniAdvisorLogAnalyzer.properties @@ -0,0 +1,15 @@ +log.analyzer.thread.count=3 + +datasource.db.driver=com.mysql.cj.jdbc.Driver +datasource.db.url=url + +spark.enable=true +spark.workload=default +spark.eventLogs.mode=rest +spark.rest.url=http://server1:18080 +spark.timeout.seconds=30 + +tez.enable=true +tez.workload=default +tez.timeline.url=http://server1:8188 +tez.timeline.timeout.ms=6000 diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/SparkFetcher.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/SparkFetcher.scala new file mode 100644 index 0000000000000000000000000000000000000000..8694d5b58fefb53a60f6fa853d2b26b50fe4ae87 --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/SparkFetcher.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.spark + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.fetcher.{Fetcher, FetcherType} +import com.huawei.boostkit.omniadvisor.models.AppResult +import com.huawei.boostkit.omniadvisor.spark.client.{SparkEventClient, SparkLogClient, SparkRestClient} +import com.huawei.boostkit.omniadvisor.spark.config.SparkFetcherConfigure +import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils +import org.apache.commons.configuration2.PropertiesConfiguration +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkConf +import org.slf4j.{Logger, LoggerFactory} + +import java.util +import java.util.Optional +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.{Duration, SECONDS} +import scala.concurrent.{Await, Future} +import scala.util.{Failure, Success, Try} + +class SparkFetcher(configure: PropertiesConfiguration) + extends Fetcher +{ + private val LOG: Logger = LoggerFactory.getLogger(classOf[SparkFetcher]) + + val sparkFetcherConfig = new SparkFetcherConfigure(configure) + + lazy val hadoopConfigure: Configuration = OmniAdvisorContext.getHadoopConfig + + lazy val sparkConf: SparkConf = { + val sparkConf = new SparkConf() + SparkUtils.getDefaultPropertiesFile() match { + case Some(fileName) => sparkConf.setAll(SparkUtils.getPropertiesFromFile(fileName)) + case None => LOG.warn("Can't find Spark conf, use default config, Please set SPARK_HOME or SPARK_CONF_DIR") + } + sparkConf + } + + lazy val sparkClient: SparkEventClient = { + if (sparkFetcherConfig.isRestMode) { + new SparkRestClient(sparkFetcherConfig.restUrl, sparkFetcherConfig.timeoutSeconds, sparkConf, + sparkFetcherConfig.workload) + } else { + new SparkLogClient(hadoopConfigure, sparkConf, sparkFetcherConfig.logDirectory, sparkFetcherConfig.workload, + sparkFetcherConfig.maxLogSizeInMB * FileUtils.ONE_MB) + } + } + + override def isEnable: Boolean = sparkFetcherConfig.enable + + override def analysis(job: AnalyticJob): Optional[AppResult] = { + val appId = job.getApplicationId + LOG.info(s"Fetching data for ${appId}") + val result = Try { + Await.result(doAnalysisApplication(job), Duration(sparkFetcherConfig.timeoutSeconds, SECONDS)) + }.transform( + data => { + LOG.info(s"Succeed fetching data for ${appId}") + Success(data) + }, + e => { + LOG.error(s"Failed fetching data for ${appId}, Exception Message is ${e.getMessage}") + Failure(e) + }) + result match { + case Success(data) => Optional.of(data) + case Failure(e) => Optional.empty() + } + } + + private def doAnalysisApplication(job: AnalyticJob): Future[AppResult] = { + Future { + sparkClient.fetchAnalyticResult(job) + } + } + + override def getType: FetcherType = FetcherType.SPARK + + override def fetchAnalyticJobs(startTimeMills: Long, finishedTimeMills: Long): util.List[AnalyticJob] = { + val jobs: util.List[AnalyticJob] = new util.ArrayList[AnalyticJob]() + sparkClient.fetchAnalyticJobs(startTimeMills, finishedTimeMills).foreach(job => jobs.add(job)) + jobs + } +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkEventClient.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkEventClient.scala new file mode 100644 index 0000000000000000000000000000000000000000..e67d3ffbd0c617e49e3eb6edf1aa2af26dff4f9e --- /dev/null +++ 
b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkEventClient.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.client + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.models.AppResult + +trait SparkEventClient { + def fetchAnalyticJobs(startTimeMills: Long, finishedTimeMills: Long): List[AnalyticJob] + + def fetchAnalyticResult(job: AnalyticJob): AppResult +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkLogClient.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkLogClient.scala new file mode 100644 index 0000000000000000000000000000000000000000..e125c3f712e62e61301ae44f6c46d4a13bf1292a --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkLogClient.scala @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.spark.client + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.models.AppResult +import com.huawei.boostkit.omniadvisor.spark.data.SparkLogAnalyticJob +import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkApplicationDataExtractor.extractAppResultFromAppStatusStore +import org.apache.spark.SparkConf +import org.apache.spark.SparkDataCollection + +class SparkLogClient(hadoopConfiguration: Configuration, sparkConf: SparkConf, eventLogUri: String, + workload: String, maxFileSize: Long) + extends SparkEventClient { + + override def fetchAnalyticJobs(startTimeMills: Long, finishedTimeMills: Long): List[AnalyticJob] = { + SparkUtils.findApplicationFiles(hadoopConfiguration, eventLogUri, startTimeMills, finishedTimeMills, maxFileSize) + .map(file => new SparkLogAnalyticJob(SparkUtils.getApplicationIdFromFile(file), file)) + .filter(job => OmniAdvisorContext.getInstance().getFinder.byId(job.getApplicationId) == null) + } + + override def fetchAnalyticResult(job: AnalyticJob): AppResult = { + require(job.isInstanceOf[SparkLogAnalyticJob], "Require SparkLogAnalyticJob") + val logJob = job.asInstanceOf[SparkLogAnalyticJob] + val path = new Path(logJob.getFilePath) + val compressCodec = SparkUtils.compressionCodecForLogName(sparkConf, path.getName) + val dataCollection = new SparkDataCollection + + SparkUtils.withEventLog( + FileSystem.get(path.toUri, hadoopConfiguration), path, compressCodec) { in => + dataCollection.replayEventLogs(in, path.toString) + } + + val appInfo = dataCollection.appInfo + val jobData = dataCollection.jobData + val environmentInfo = dataCollection.environmentInfo + + extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobData) + } +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkRestClient.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkRestClient.scala new file mode 100644 index 0000000000000000000000000000000000000000..dcdcf8e6ebc932cb38693c7f1aa6e0a494f5d18f --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/client/SparkRestClient.scala @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.spark.client + +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.module.scala.{DefaultScalaModule, ScalaObjectMapper} +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException +import com.huawei.boostkit.omniadvisor.models.AppResult +import com.huawei.boostkit.omniadvisor.spark.data.SparkRestAnalyticJob +import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils +import org.apache.spark.SparkConf +import org.apache.spark.SparkDataCollection +import org.apache.spark.status.api.v1.ApplicationInfo +import org.glassfish.jersey.client.ClientProperties +import org.slf4j.{Logger, LoggerFactory} + +import java.io.{BufferedInputStream, InputStream} +import java.net.URI +import java.text.SimpleDateFormat +import java.util.{Calendar, Date, SimpleTimeZone} +import java.util.zip.ZipInputStream +import javax.ws.rs.client.{Client, ClientBuilder, WebTarget} +import javax.ws.rs.core.MediaType +import scala.collection.mutable.ListBuffer +import scala.concurrent.duration.{Duration, FiniteDuration, SECONDS} +import scala.util.control.NonFatal + +class SparkRestClient(historyUri: String, timeoutSeconds: Int, sparkConf: SparkConf, workload: String) + extends SparkEventClient { + private val LOG: Logger = LoggerFactory.getLogger(classOf[SparkRestClient]) + + private val historyServerUri: URI = { + val baseUri: URI = { + if (historyUri.contains("http://")) { + new URI(historyUri) + } else { + new URI(s"http://${historyUri}") + } + } + require(baseUri.getPath == "") + baseUri + } + + val timeout: FiniteDuration = Duration(timeoutSeconds, SECONDS) + val API_V1_MOUNT_PATH = "api/v1" + val IN_PROGRESS = ".inprogress" + + val sparkRestObjectMapper: ObjectMapper with ScalaObjectMapper = { + val dateFormat = { + val formatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) + formatter.setCalendar(cal) + formatter + } + + val objectMapper = new ObjectMapper() with ScalaObjectMapper + objectMapper.setDateFormat(dateFormat) + objectMapper.registerModule(DefaultScalaModule) + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + objectMapper + } + + private val client: Client = ClientBuilder.newClient() + + private var apiTarget: WebTarget = client.property(ClientProperties.CONNECT_TIMEOUT, timeout.toMillis.toInt) + .property(ClientProperties.READ_TIMEOUT, timeout.toMillis.toInt) + .target(historyServerUri) + .path(API_V1_MOUNT_PATH) + + protected def setApiTarget(apiTarget: WebTarget): Unit = { + this.apiTarget = apiTarget + } + + override def fetchAnalyticJobs(startTimeMills: Long, finishedTimeMills: Long): List[AnalyticJob] = { + val minDate = sparkRestObjectMapper.getDateFormat.format(new Date(startTimeMills)) + val maxDate = sparkRestObjectMapper.getDateFormat.format(new Date(finishedTimeMills)) + val appTarget = apiTarget.path("applications").queryParam("minDate", minDate).queryParam("maxDate", maxDate) + + try { + LOG.info(s"calling REST API at ${appTarget.getUri}") + val applications = getApplications(appTarget, sparkRestObjectMapper.readValue[Seq[ApplicationInfo]]) + .filter(job => OmniAdvisorContext.getInstance().getFinder.byId(job.id) == null) + val analyticJobs = new ListBuffer[AnalyticJob]() + for (appInfo <- applications) { + val attempts = appInfo.attempts + if 
(attempts.isEmpty) { + LOG.info("application {} attempt is empty, skip it", appInfo.id) + } else { + analyticJobs += new SparkRestAnalyticJob(appInfo.id) + } + } + analyticJobs.toList + } catch { + case NonFatal(e) => + LOG.error(s"error reading jobData ${appTarget.getUri}. Exception Message = ${e}") + throw new OmniAdvisorException(e) + } + } + + override def fetchAnalyticResult(job: AnalyticJob): AppResult = { + require(job.isInstanceOf[SparkRestAnalyticJob], "Require SparkRestAnalyticJob") + val sparkJob = job.asInstanceOf[SparkRestAnalyticJob] + val attemptTarget = getApplicationMetaData(sparkJob.getApplicationId) + val logTarget = attemptTarget.path("logs") + LOG.info(s"creating SparkApplication by calling REST API at ${logTarget.getUri} to get eventLogs") + resource.managed { + getApplicationLogs(logTarget) + }.acquireAndGet{ zipInputStream => + getLogInputStream(zipInputStream, logTarget) match { + case (None, _) => + throw new OmniAdvisorException(s"Failed to read log for application ${sparkJob.getApplicationId}") + case (Some(inputStream), fileName) => + val dataCollection = new SparkDataCollection() + dataCollection.replayEventLogs(inputStream, fileName) + dataCollection.getAppResult(workload) + } + } + } + + def getApplications[T](webTarget: WebTarget, converter: String => T): T = { + converter(webTarget.request(MediaType.APPLICATION_JSON).get(classOf[String])) + } + + private def getApplicationMetaData(appId: String): WebTarget = { + val appTarget = apiTarget.path(s"applications/${appId}") + LOG.info(s"calling REST API at ${appTarget.getUri}") + + val applicationInfo = getApplicationInfo(appTarget) + + val lastAttemptId = applicationInfo.attempts.maxBy{ + _.startTime + }.attemptId + lastAttemptId.map(appTarget.path).getOrElse(appTarget) + } + + private def getApplicationInfo(appTarget: WebTarget): ApplicationInfo = { + try { + getApplications(appTarget, sparkRestObjectMapper.readValue[ApplicationInfo]) + } catch { + case NonFatal(e) => + LOG.error(s"error reading applicationInfo ${appTarget.getUri}. Exception Message = ${e.getMessage}") + throw e + } + } + + private def getApplicationLogs(logTarget: WebTarget): ZipInputStream = { + try { + val inputStream = logTarget.request(MediaType.APPLICATION_OCTET_STREAM) + .get(classOf[java.io.InputStream]) + new ZipInputStream(new BufferedInputStream(inputStream)) + }catch { + case NonFatal(e) => + LOG.error(s"error reading logs ${logTarget.getUri}. 
Exception Message = ${e.getMessage}") + throw e + } + } + + private def getLogInputStream(zis: ZipInputStream, attemptTarget: WebTarget): (Option[InputStream], String) = { + val entry = zis.getNextEntry + if (entry == null) { + LOG.warn(s"failed to resolve log for ${attemptTarget.getUri}") + (None, "") + } else { + val entryName = entry.getName + if (entryName.equals(IN_PROGRESS)) { + throw new OmniAdvisorException(s"Application for the log ${entryName} has not finished yes.") + } + val codec = SparkUtils.compressionCodecForLogName(sparkConf, entryName) + (Some(codec.map{ + _.compressedInputStream(zis) + }.getOrElse(zis)), entryName) + } + } +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/config/SparkFetcherConfigure.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/config/SparkFetcherConfigure.scala new file mode 100644 index 0000000000000000000000000000000000000000..f9563b8d29cd84c9e1f2c0249160ffb0498db06e --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/config/SparkFetcherConfigure.scala @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.config + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException +import org.apache.commons.configuration2.PropertiesConfiguration + +class SparkFetcherConfigure(propertiesConfiguration: PropertiesConfiguration) { + val enableKey = "spark.enable" + val eventLogModeKey = "spark.eventLogs.mode" + val workloadKey = "spark.workload" + val defaultWorkload = "default" + val restEventLogMode = "rest" + val restUrlKey = "spark.rest.url" + val defaultRestUrl = "http://localhost:18080" + val timeoutKey = "spark.timeout.seconds" + val logEventLogMode = "log" + val logDirectoryKey = "spark.log.directory" + val maxLogFileSizeInMBKey = "spark.log.maxSize.mb" + + val defaultTimeoutSeconds = 30 + val defaultMaxLogSize = 500 + + val enable: Boolean = propertiesConfiguration.getBoolean(enableKey, false) + val mode: String = propertiesConfiguration.getString(eventLogModeKey) + val restUrl: String = propertiesConfiguration.getString(restUrlKey, defaultRestUrl) + val timeoutSeconds: Int = propertiesConfiguration.getInt(timeoutKey, defaultTimeoutSeconds) + val logDirectory: String = propertiesConfiguration.getString(logDirectoryKey, "") + val maxLogSizeInMB: Int = propertiesConfiguration.getInt(maxLogFileSizeInMBKey, defaultMaxLogSize) + val workload: String = propertiesConfiguration.getString(workloadKey, defaultWorkload) + + def isRestMode: Boolean = { + if (mode.equals(restEventLogMode)) { + true + } else if (mode.equals(logEventLogMode)) { + false + } else { + throw new OmniAdvisorException(s"Unknown event log mode ${mode}") + } + } +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkLogAnalyticJob.scala 
b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkLogAnalyticJob.scala new file mode 100644 index 0000000000000000000000000000000000000000..6eb085de7439d99ac42345df0bd7f7637d53031e --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkLogAnalyticJob.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.data + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType + +class SparkLogAnalyticJob(applicationId: String, filePath: String) extends AnalyticJob { + override def getApplicationId: String = applicationId + + override def getType: FetcherType = FetcherType.SPARK + + def getFilePath:String = filePath +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkRestAnalyticJob.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkRestAnalyticJob.scala new file mode 100644 index 0000000000000000000000000000000000000000..6e73f5c0dade1a3af9e0f222ad0d0472f24806b7 --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/data/SparkRestAnalyticJob.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.data + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType + +class SparkRestAnalyticJob (applicationId: String) extends AnalyticJob { + override def getApplicationId: String = applicationId + + override def getType: FetcherType = FetcherType.SPARK +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/ScalaUtils.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/ScalaUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..e18f446a0aed26d5256b0fd30bdbdfc1966d29d8 --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/ScalaUtils.scala @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.utils + +import com.alibaba.fastjson.JSONObject +import org.apache.spark.JobExecutionStatus +import org.apache.spark.status.api.v1.JobData + +object ScalaUtils { + def parseMapToJsonString(map: Map[String, String]): String = { + val json = new JSONObject + map.foreach(m => { + json.put(m._1, m._2) + }) + json.toJSONString + } + + def checkSuccess(jobs: Seq[JobData]): Boolean = { + !jobs.exists(_.status.equals(JobExecutionStatus.FAILED)) + } +} diff --git a/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/SparkUtils.scala b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/SparkUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7a1f725896d0c21cc544c360f81dc9034e42f36 --- /dev/null +++ b/omniadvisor/src/main/scala/com/huawei/boostkit/omniadvisor/spark/utils/SparkUtils.scala @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.spark.utils + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkConf +import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec, LZFCompressionCodec, SnappyCompressionCodec, ZStdCompressionCodec} + +import java.io.{BufferedInputStream, File, FileInputStream, FileNotFoundException, InputStream} +import java.net.URI +import java.util.Properties +import scala.collection.JavaConverters.asScalaSetConverter +import scala.collection.mutable +import scala.tools.jline_embedded.internal.InputStreamReader + +object SparkUtils { + def findApplicationFiles(hadoopConfiguration: Configuration, eventLogDir: String, startTimeMills: Long, + finishTimeMills: Long, maxFileSize: Long): List[String] = { + val uri = new URI(eventLogDir) + val fs = FileSystem.get(uri, hadoopConfiguration) + val eventLogDirPath: Path = new Path(eventLogDir) + if (fs.exists(eventLogDirPath) && fs.getFileStatus(eventLogDirPath).isDirectory) { + fs.listStatus(eventLogDirPath).filter(status => { + val fileSize = status.getLen + val modifyTime = status.getModificationTime + modifyTime >= startTimeMills && modifyTime <= finishTimeMills && fileSize <= maxFileSize + }).map { status => status.getPath.toString }.toList + } else { + throw new OmniAdvisorException("eventLog path is not exist or not a Directory") + } + } + + private val IN_PROGRESS = ".inprogress" + + private val compressionCodecClassNamesByShortName = Map( + "lz4" -> classOf[LZ4CompressionCodec].getName, + "lzf" -> classOf[LZFCompressionCodec].getName, + "snappy" -> classOf[SnappyCompressionCodec].getName, + "zstd" -> classOf[ZStdCompressionCodec].getName) + + private val compressionCodecMap = mutable.HashMap.empty[String, CompressionCodec] + + private def loadCompressionCodec(conf: SparkConf, codecName: String): CompressionCodec = { + val codecClass = compressionCodecClassNamesByShortName.getOrElse(codecName.toLowerCase, codecName) + val classLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader) + val codec = try { + val constructor = Class.forName(codecClass, true, classLoader).getConstructor(classOf[SparkConf]) + Some(constructor.newInstance(conf).asInstanceOf[CompressionCodec]) + } catch { + case _: ClassNotFoundException => None + case _: IllegalArgumentException => None + } + codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available.")) + } + + def compressionCodecForLogName(conf: SparkConf, logName: String): Option[CompressionCodec] = { + val logBaseName = logName.stripSuffix(IN_PROGRESS) + logBaseName.split("\\.").tail.lastOption.map { + codecName => + compressionCodecMap.getOrElseUpdate(codecName, loadCompressionCodec(conf, codecName)) + } + } + + def getApplicationIdFromFile(file: String): String = { + val fileName = new Path(file).getName + val logBaseName = fileName.stripSuffix(IN_PROGRESS) + logBaseName.split("\\.").apply(0) + } + + def withEventLog[T](fs: FileSystem, path: Path, codec: Option[CompressionCodec])(f: InputStream => T): T = { + resource.managed { openEventLog(path, fs)} + .map { in => codec.map { _.compressedInputStream(in) }.getOrElse(in) } + .acquireAndGet(f) + } + + private def openEventLog(logPath: Path, fs: FileSystem): InputStream = { + if (!fs.exists(logPath)) { + throw new FileNotFoundException(s"File ${logPath} does not exist.") + } + + new 
BufferedInputStream(fs.open(logPath)) + } + + def defaultEnv: Map[String, String] = sys.env + + def getDefaultPropertiesFile(env: Map[String, String] = defaultEnv): Option[String] = { + env.get("SPARK_CONF_DIR").orElse(env.get("SPARK_HOME").map { + t => s"$t${File.separator}conf"}) + .map {t => new File(s"$t${File.separator}spark-defaults.conf")} + .filter(_.isFile) + .map(_.getAbsolutePath) + } + + def getPropertiesFromFile(fileName: String): Map[String, String] = { + val file = new File(fileName) + require(file.exists(), s"Properties file $file does not exist") + require(file.isFile, s"Properties file $file is not a normal file") + + val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + try { + val properties = new Properties() + properties.load(inReader) + properties.stringPropertyNames().asScala.map( + k => (k, properties.getProperty(k).trim)).toMap + } finally { + inReader.close() + } + } +} diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala new file mode 100644 index 0000000000000000000000000000000000000000..d5b2b598ffb405559685b9ee4f3265b2c0f6915c --- /dev/null +++ b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType +import com.huawei.boostkit.omniadvisor.models.AppResult +import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.{checkSuccess, parseMapToJsonString} +import com.nimbusds.jose.util.StandardCharset +import org.apache.spark.status.api.v1._ +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.mutable +import scala.io.{BufferedSource, Source} + +object SparkApplicationDataExtractor { + val LOG: Logger = LoggerFactory.getLogger(SparkApplicationDataExtractor.getClass) + + val SPARK_REQUIRED_PARAMS_FILE = "SparkParams" + + def extractAppResultFromAppStatusStore(appInfo: ApplicationInfo, + workload: String, + environmentInfo: ApplicationEnvironmentInfo, + jobsList: Seq[JobData]): AppResult = { + val appResult = new AppResult + appResult.applicationId = appInfo.id + appResult.applicationName = appInfo.name + appResult.jobType = FetcherType.SPARK.getName + appResult.applicationWorkload = workload + + val attempt: ApplicationAttemptInfo = lastAttempt(appInfo) + appResult.startTime = attempt.startTime.getTime + appResult.finishTime = attempt.endTime.getTime + + val configurations: Map[String, String] = extractAppConfigurations(environmentInfo) + appResult.parameters = parseMapToJsonString(extractRequiredConfiguration(configurations)) + + if (!attempt.completed) { + // In this case, the task is killed, consider as a failed task + appResult.executionStatus = AppResult.FAILED_STATUS + appResult.durationTime = AppResult.FAILED_JOB_DURATION + appResult.query = "" + } else { + if (jobsList.nonEmpty) { + val query: Option[String] = jobsList.maxBy(job => job.jobId).description + + if (checkSuccess(jobsList)) { + appResult.executionStatus = AppResult.SUCCEEDED_STATUS + appResult.durationTime = attempt.duration + appResult.query = query.getOrElse("") + } else { + appResult.executionStatus = AppResult.FAILED_STATUS + appResult.durationTime = AppResult.FAILED_JOB_DURATION + appResult.query = "" + } + } else { + appResult.query = "" + appResult.executionStatus = AppResult.FAILED_STATUS + appResult.durationTime = AppResult.FAILED_JOB_DURATION + } + } + + appResult + } + + private def extractRequiredConfiguration(sparkConfigure: Map[String, String]): Map[String, String] = { + var sparkParamsFile: BufferedSource = null + try { + sparkParamsFile = Source.fromFile(Thread.currentThread().getContextClassLoader + .getResource(SPARK_REQUIRED_PARAMS_FILE).getPath, StandardCharset.UTF_8.name) + val params: Iterator[String] = sparkParamsFile.getLines() + val requiredParams = new mutable.HashMap[String, String]() + for (param <- params) { + val paramRequired = param.trim + if (paramRequired.nonEmpty) { + requiredParams.put(paramRequired, sparkConfigure.getOrElse(paramRequired, "")) + } + } + requiredParams.toMap[String, String] + } finally { + if (sparkParamsFile.nonEmpty) { + sparkParamsFile.close + } + } + } + + private def extractAppConfigurations(environmentInfo: ApplicationEnvironmentInfo): Map[String, String] = { + environmentInfo.sparkProperties.toMap + } + + def lastAttempt(applicationInfo: ApplicationInfo): ApplicationAttemptInfo = { + require(applicationInfo.attempts.nonEmpty) + applicationInfo.attempts.last + } +} diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala new file mode 100644 index 0000000000000000000000000000000000000000..d738bd3d2c49d75bc246b08f4bcd0a4faed2a09f --- /dev/null +++ 
b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import com.huawei.boostkit.omniadvisor.models.AppResult +import org.apache.spark.status.api.v1 +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} +import org.apache.spark.internal.config.Status.ASYNC_TRACKING_ENABLED +import org.apache.spark.scheduler.ReplayListenerBus +import org.apache.spark.status.{AppStatusListener, AppStatusStore, ElementTrackingStore} +import org.apache.spark.util.Utils +import org.slf4j.{Logger, LoggerFactory} + +import java.io.InputStream + +class SparkDataCollection { + val LOG: Logger = LoggerFactory.getLogger(classOf[SparkDataCollection]) + + private val conf = new SparkConf + + var environmentInfo: v1.ApplicationEnvironmentInfo = _ + var jobData: Seq[v1.JobData] = _ + var appInfo: v1.ApplicationInfo = _ + + def replayEventLogs(in: InputStream, sourceName: String): Unit = { + val store: KVStore = createInMemoryStore() + val replayConf: SparkConf = conf.clone().set(ASYNC_TRACKING_ENABLED, false) + val trackingStore: ElementTrackingStore = new ElementTrackingStore(store, replayConf) + val replayBus: ReplayListenerBus = new ReplayListenerBus() + val listener = new AppStatusListener(trackingStore, replayConf, false) + replayBus.addListener(listener) + + try { + replayBus.replay(in, sourceName, maybeTruncated = true) + trackingStore.close(false) + } catch { + case e: Exception => + Utils.tryLogNonFatalError { + trackingStore.close() + } + throw e + } + LOG.info("Replay of logs complete") + val appStatusStore: AppStatusStore = new AppStatusStore(store) + appInfo = appStatusStore.applicationInfo() + environmentInfo = appStatusStore.environmentInfo() + jobData = appStatusStore.jobsList(null) + appStatusStore.close() + } + + def getAppResult(workload: String): AppResult = { + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobData) + } + + private def createInMemoryStore(): KVStore = { + val store = new InMemoryStore() + store + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/TestOmniAdvisor.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/TestOmniAdvisor.java new file mode 100644 index 0000000000000000000000000000000000000000..3372be5638fdefd10cb960a532d394cbfb435d36 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/TestOmniAdvisor.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
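Taken together, the two Scala classes added above form the log-replay path: `SparkDataCollection.replayEventLogs` rebuilds an in-memory `AppStatusStore` from a Spark event log, and `SparkApplicationDataExtractor` reduces it to an `AppResult`, keeping only the configuration keys listed in the `SparkParams` classpath resource. A minimal sketch of that flow, assuming an uncompressed event log at a hypothetical local path and the `default` workload:

```scala
// Illustrative only, not part of the patch: replay one uncompressed event log and
// print the extracted result. Compressed logs would additionally need the codec
// resolution that SparkUtils.compressionCodecForLogName provides.
import java.io.{BufferedInputStream, FileInputStream}

import org.apache.spark.SparkDataCollection

object ReplaySketch {
  def main(args: Array[String]): Unit = {
    val logPath = "/tmp/application_1516285256255_0012" // hypothetical local copy of an event log
    val collection = new SparkDataCollection
    val in = new BufferedInputStream(new FileInputStream(logPath))
    try {
      collection.replayEventLogs(in, logPath)
    } finally {
      in.close()
    }
    val result = collection.getAppResult("default")
    println(s"${result.applicationId}: status=${result.executionStatus}, duration=${result.durationTime} ms")
  }
}
```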
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor; + +import com.huawei.boostkit.omniadvisor.configuration.BaseTestConfiguration; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import org.junit.Test; + +public class TestOmniAdvisor extends BaseTestConfiguration { + @Test + public void testOmniTuning() { + OmniAdvisor.main(new String[]{"2020-09-02 00:00:00", "2020-09-02 00:00:00", "user", "passwd"}); + } + + @Test(expected = OmniAdvisorException.class) + public void testErrorNumberParams() { + OmniAdvisor.main(new String[] {"2020-09-02 00:00:00", "2020-09-02 00:00:00"}); + } + + @Test(expected = OmniAdvisorException.class) + public void testErrorTimeParser() { + OmniAdvisor.main(new String[] {"2020-09-02 00-00-00", "2020-09-02 00-00-00", "user", "pass"}); + } + + @Test(expected = OmniAdvisorException.class) + public void testErrorTimeOrder() { + OmniAdvisor.main(new String[] {"2020-09-02 20:00:00", "2020:09:02 00-00-00", "user", "pass"}); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/BaseTestConfiguration.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/BaseTestConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..b9f4271541dfe403c53df7f7ce90b9e2bcdecdd3 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/BaseTestConfiguration.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import io.ebean.Finder; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.apache.commons.configuration2.builder.FileBasedConfigurationBuilder; +import org.apache.commons.configuration2.builder.fluent.Configurations; +import org.apache.commons.configuration2.ex.ConfigurationException; +import org.junit.BeforeClass; +import org.mockito.Mockito; + +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.Locale; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +public class BaseTestConfiguration { + private static final String TESTING_CONFIG_FILE = "omniAdvisorLogAnalyzer.properties"; + private static final String ENCODING = StandardCharsets.UTF_8.displayName(Locale.ENGLISH); + + protected static PropertiesConfiguration testConfiguration; + + @BeforeClass + public static void setUp() throws ConfigurationException { + Configurations configurations = new Configurations(); + URL configUrl = Thread.currentThread().getContextClassLoader().getResource(TESTING_CONFIG_FILE); + FileBasedConfigurationBuilder.setDefaultEncoding(OmniAdvisorConfigure.class, ENCODING); + testConfiguration = configurations.properties(configUrl); + + OmniAdvisorContext.initContext(); + OmniAdvisorContext context = OmniAdvisorContext.getInstance(); + Finder finder = Mockito.mock(Finder.class); + when(finder.byId(any())).thenReturn(null); + context.setFinder(finder); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestConfiguration.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..752add90b578668a0a2097784f0742925f87dc9f --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestConfiguration.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.apache.commons.configuration2.builder.FileBasedConfigurationBuilder; +import org.apache.commons.configuration2.builder.fluent.Configurations; +import org.apache.commons.configuration2.ex.ConfigurationException; +import org.junit.BeforeClass; + +import java.net.URL; +import java.nio.charset.StandardCharsets; + +public class TestConfiguration { + private static final String TESTING_CONFIG_FILE_NAME = "TestingConfigure.properties"; + private static final String TESTING_SPARK_CONFIG_FILE_NAME = "TestingSparkConfigure.properties"; + private static final String ENCODING = StandardCharsets.UTF_8.displayName(); + + protected static PropertiesConfiguration testConfiguration; + protected static PropertiesConfiguration testSparkConfiguration; + + @BeforeClass + public static void setUpClass() throws ConfigurationException { + Configurations configurations = new Configurations(); + URL configFileUrl = Thread.currentThread().getContextClassLoader().getResource(TESTING_CONFIG_FILE_NAME); + URL sparkConfig = Thread.currentThread().getContextClassLoader().getResource(TESTING_SPARK_CONFIG_FILE_NAME); + FileBasedConfigurationBuilder.setDefaultEncoding(OmniAdvisorConfigure.class, ENCODING); + testConfiguration = configurations.properties(configFileUrl); + testSparkConfiguration = configurations.properties(sparkConfig); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestDBConfigure.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestDBConfigure.java new file mode 100644 index 0000000000000000000000000000000000000000..1e4a67e44199f91f4abe8224f482f4e7a8e5bdaa --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestDBConfigure.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import io.ebean.config.DatabaseConfig; +import io.ebean.datasource.DataSourceInitialiseException; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.junit.Test; +import org.mockito.Mockito; + +import javax.sql.DataSource; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +public class TestDBConfigure { + @Test(expected = DataSourceInitialiseException.class) + public void testDBConfigureWithErrorUrl() { + PropertiesConfiguration testConfiguration = Mockito.mock(PropertiesConfiguration.class); + when(testConfiguration.getString("datasource.db.driver", "com.mysql.cj.jdbc.Driver")) + .thenReturn("com.mysql.cj.jdbc.Driver"); + when(testConfiguration.getString("datasource.db.url")).thenReturn("jdbc://mysql:errorUrl"); + DBConfigure.initDatabase(testConfiguration, "user", "passwd"); + } + + @Test + public void testInitDataBase() throws SQLException { + ResultSet resultSet = Mockito.mock(ResultSet.class); + DatabaseMetaData metaData = Mockito.mock(DatabaseMetaData.class); + Connection connection = Mockito.mock(Connection.class); + DataSource dataSource = Mockito.mock(DataSource.class); + DatabaseConfig dbConfig = Mockito.mock(DatabaseConfig.class); + + when(resultSet.next()).thenReturn(true); + when(metaData.getTables(any(), any(), any(), any())).thenReturn(resultSet); + when(connection.getMetaData()).thenReturn(metaData); + when(dbConfig.getDataSource()).thenReturn(dataSource); + when(dataSource.getConnection()).thenReturn(connection); + DBConfigure.checkInit(dbConfig); + } + + @Test + public void testNotInitDatabase() throws SQLException { + ResultSet resultSet = Mockito.mock(ResultSet.class); + DatabaseMetaData metaData = Mockito.mock(DatabaseMetaData.class); + Connection connection = Mockito.mock(Connection.class); + DataSource dataSource = Mockito.mock(DataSource.class); + DatabaseConfig dbConfig = Mockito.mock(DatabaseConfig.class); + + when(resultSet.next()).thenReturn(false); + when(metaData.getTables(any(), any(), any(), any())).thenReturn(resultSet); + when(connection.getMetaData()).thenReturn(metaData); + when(dbConfig.getDataSource()).thenReturn(dataSource); + when(dataSource.getConnection()).thenReturn(connection); + DBConfigure.checkInit(dbConfig); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestOmniAdvisorConfigure.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestOmniAdvisorConfigure.java new file mode 100644 index 0000000000000000000000000000000000000000..0fd25f6cf0f9fc1f0c8784a25fb82b73343e2276 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/configuration/TestOmniAdvisorConfigure.java @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.configuration; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class TestOmniAdvisorConfigure extends TestConfiguration { + @Test + public void getOmniTuningConfigure() { + assertEquals(OmniAdvisorContext.getInstance().getOmniAdvisorConfigure().getThreadCount(), 3); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/exception/TestOmniAdvisorException.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/exception/TestOmniAdvisorException.java new file mode 100644 index 0000000000000000000000000000000000000000..2b2b8efe3d9078badd52356b561674bd040ab1c3 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/exception/TestOmniAdvisorException.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.exception; + +import org.junit.Test; + +public class TestOmniAdvisorException { + @Test(expected = OmniAdvisorException.class) + public void testThrowExceptionWithMessage() { + throw new OmniAdvisorException("errorMessage"); + } + + @Test(expected = OmniAdvisorException.class) + public void testThrowExceptionWithThrowable() { + throw new OmniAdvisorException(new IllegalArgumentException()); + } + + @Test(expected = OmniAdvisorException.class) + public void testThrowExceptionWithMessageAndThrowable() { + throw new OmniAdvisorException("message", new RuntimeException()); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/executor/TestOmniAdvisorRunner.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/executor/TestOmniAdvisorRunner.java new file mode 100644 index 0000000000000000000000000000000000000000..cd3aff958b22992d717f9c0b12352fcb47107366 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/executor/TestOmniAdvisorRunner.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.executor; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import com.huawei.boostkit.omniadvisor.configuration.BaseTestConfiguration; +import com.huawei.boostkit.omniadvisor.fetcher.Fetcher; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.spark.SparkFetcher; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.junit.Test; +import org.mockito.Mockito; + +import java.net.URL; + +import static org.mockito.Mockito.when; + +public class TestOmniAdvisorRunner extends BaseTestConfiguration { + @Test + public void testOmniTuningRunner() { + PropertiesConfiguration sparkConfig = Mockito.mock(PropertiesConfiguration.class); + when(sparkConfig.getBoolean("spark.enable", false)).thenReturn(true); + when(sparkConfig.getString("spark.workload", "default")).thenReturn("default"); + when(sparkConfig.getString("spark.eventLogs.mode")).thenReturn("log"); + when(sparkConfig.getInt("spark.timeout.seconds", 30)).thenReturn(30); + URL resource = Thread.currentThread().getContextClassLoader().getResource("spark-events"); + when(sparkConfig.getString("spark.log.directory", "")).thenReturn(resource.getPath()); + when(sparkConfig.getInt("spark.log.maxSize.mb", 500)).thenReturn(500); + Fetcher sparkFetcher = new SparkFetcher(sparkConfig); + + OmniAdvisorContext.getInstance().getFetcherFactory().addFetcher(FetcherType.SPARK, sparkFetcher); + + OmniAdvisorRunner runner = new OmniAdvisorRunner(0L, Long.MAX_VALUE); + runner.run(); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/fetcher/TestFetcherFactory.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/fetcher/TestFetcherFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..b7d615b779681b6360e68b9a88a818ae74c2bf37 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/fetcher/TestFetcherFactory.java @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.fetcher; + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.spark.SparkFetcher; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.junit.Test; +import org.mockito.Mockito; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; + +public class TestFetcherFactory { + @Test + public void testFetcherFactory() { + PropertiesConfiguration config = Mockito.mock(PropertiesConfiguration.class); + + when(config.getBoolean("spark.enable", false)).thenReturn(true); + when(config.getBoolean("tez.enable", false)).thenReturn(false); + + FetcherFactory fetcherFactory = new FetcherFactory(config); + assertEquals(fetcherFactory.getAllFetchers().size(), 1); + } + + @Test + public void testFetcherFactoryWithEmptyFetcher() { + PropertiesConfiguration config = Mockito.mock(PropertiesConfiguration.class); + + when(config.getBoolean("spark.enable", false)).thenReturn(false); + when(config.getBoolean("tez.enable", false)).thenReturn(false); + + FetcherFactory fetcherFactory = new FetcherFactory(config); + assertEquals(fetcherFactory.getAllFetchers().size(), 0); + } + + @Test + public void testGetFetcher() { + PropertiesConfiguration config = Mockito.mock(PropertiesConfiguration.class); + + when(config.getBoolean("spark.enable", false)).thenReturn(true); + when(config.getBoolean("tez.enable", false)).thenReturn(false); + + FetcherFactory fetcherFactory = new FetcherFactory(config); + assertEquals(fetcherFactory.getFetcher(FetcherType.SPARK).getClass(), SparkFetcher.class); + } + + @Test(expected = OmniAdvisorException.class) + public void testGetUnknownFetcher() { + PropertiesConfiguration config = Mockito.mock(PropertiesConfiguration.class); + + when(config.getBoolean("spark.enable", false)).thenReturn(false); + when(config.getBoolean("tez.enable", false)).thenReturn(false); + + FetcherFactory fetcherFactory = new FetcherFactory(config); + fetcherFactory.getFetcher(FetcherType.TEZ); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/security/TestHadoopSecurity.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/security/TestHadoopSecurity.java new file mode 100644 index 0000000000000000000000000000000000000000..c7cc2f4a4cbfc3712d373d6564bfd232084434c1 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/security/TestHadoopSecurity.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.security; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import com.huawei.boostkit.omniadvisor.configuration.OmniAdvisorConfigure; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CommonConfigurationKeys; +import org.apache.hadoop.minikdc.MiniKdc; +import org.apache.hadoop.security.UserGroupInformation; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.File; +import java.util.Locale; +import java.util.Properties; + +public class TestHadoopSecurity { + private static Configuration conf; + private static MiniKdc kdc; + private static File keytab; + + @BeforeClass + public static void setupKdc() throws Exception { + conf = new Configuration(); + conf.set(CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION, + UserGroupInformation.AuthenticationMethod.KERBEROS.toString().toLowerCase(Locale.ENGLISH)); + UserGroupInformation.setConfiguration(conf); + + final String principal = "test"; + final File workDir = new File(System.getProperty("test.dir", "target")); + keytab = new File(workDir, "test.keytab"); + + Properties kdcConf = MiniKdc.createConf(); + kdc = new MiniKdc(kdcConf, workDir); + kdc.start(); + kdc.createPrincipal(keytab, principal); + } + + @AfterClass + public static void tearDown() { + UserGroupInformation.reset(); + if (kdc != null) { + kdc.stop(); + } + } + + @After + public void clearProperties() { + OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure(); + configure.setKerberosPrincipal(null); + configure.setKerberosKeytabFile(null); + } + + @Test + public void testHadoopSecurity() throws Exception { + OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure(); + configure.setKerberosPrincipal("test"); + configure.setKerberosKeytabFile(keytab.getAbsolutePath()); + HadoopSecurity security = new HadoopSecurity(conf); + security.checkLogin(); + } + + @Test(expected = OmniAdvisorException.class) + public void testHadoopSecurityWithoutKeytabUser() throws Exception { + OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure(); + configure.setKerberosKeytabFile(keytab.getAbsolutePath()); + HadoopSecurity security = new HadoopSecurity(conf); + security.checkLogin(); + } + + @Test(expected = OmniAdvisorException.class) + public void testHadoopSecurityWithoutKeytabLocation() throws Exception { + OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure(); + configure.setKerberosPrincipal("test"); + HadoopSecurity security = new HadoopSecurity(conf); + security.checkLogin(); + } + + @Test(expected = OmniAdvisorException.class) + public void testHadoopSecurityWithErrorKeytabFile() throws Exception { + OmniAdvisorConfigure configure = OmniAdvisorContext.getInstance().getOmniAdvisorConfigure(); + configure.setKerberosPrincipal("test"); + configure.setKerberosKeytabFile("errorPath"); + HadoopSecurity security = new HadoopSecurity(conf); + security.checkLogin(); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/TestSparkFetcher.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/TestSparkFetcher.java new file mode 100644 index 0000000000000000000000000000000000000000..4c1cd94f24a6e68d552e6fa64c0602964cbf562c --- /dev/null +++ 
b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/TestSparkFetcher.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark; + +import com.huawei.boostkit.omniadvisor.OmniAdvisorContext; +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import com.huawei.boostkit.omniadvisor.spark.data.SparkLogAnalyticJob; +import io.ebean.Finder; +import org.apache.commons.configuration2.PropertiesConfiguration; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.Mockito; + +import java.net.URL; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +public class TestSparkFetcher { + private static String testResourcePath; + private static SparkFetcher sparkFetcher; + + @BeforeClass + public static void setUp() { + PropertiesConfiguration sparkConfig = Mockito.mock(PropertiesConfiguration.class); + when(sparkConfig.getBoolean("spark.enable", false)).thenReturn(true); + when(sparkConfig.getString("spark.workload", "default")).thenReturn("default"); + when(sparkConfig.getString("spark.eventLogs.mode")).thenReturn("log"); + when(sparkConfig.getInt("spark.timeout.seconds", 30)).thenReturn(30); + URL resource = Thread.currentThread().getContextClassLoader().getResource("spark-events"); + testResourcePath = resource.getPath(); + when(sparkConfig.getString("spark.log.directory", "")).thenReturn(resource.getPath()); + when(sparkConfig.getInt("spark.log.maxSize.mb", 500)).thenReturn(500); + sparkFetcher = new SparkFetcher(sparkConfig); + } + + @Test + public void testEnable() { + assertTrue(sparkFetcher.isEnable()); + } + + @Test + public void testFetcherType() { + assertEquals(sparkFetcher.getType(), FetcherType.SPARK); + } + + @Test + public void testGetApplications() { + OmniAdvisorContext.initContext(); + Finder finder = Mockito.mock(Finder.class); + when(finder.byId(any())).thenReturn(null); + OmniAdvisorContext.getInstance().setFinder(finder); + List jobs = sparkFetcher.fetchAnalyticJobs(0L, Long.MAX_VALUE); + assertEquals(jobs.size(), 1); + } + + @Test + public void testAnalysis() { + SparkLogAnalyticJob logJob = new SparkLogAnalyticJob("appId", + testResourcePath + System.getProperty("file.separator") + "application_1516285256255_0012"); + Optional result = sparkFetcher.analysis(logJob); + assertTrue(result.isPresent()); + AppResult appResult = result.get(); + assertEquals(appResult.applicationId, "application_1516285256255_0012"); + assertEquals(appResult.applicationName, "Spark shell"); + assertEquals(appResult.applicationWorkload, "default"); + assertEquals(appResult.startTime, 1516300235119L); + 
assertEquals(appResult.finishTime, 1516300707938L); + assertEquals(appResult.durationTime, 472819L); + assertEquals(appResult.jobType, "SPARK"); + assertEquals(appResult.parameters, "{\"spark.executor.memory\":\"2G\",\"spark.executor.cores\":\"\",\"spark.executor.instances\":\"8\"}"); + assertEquals(appResult.query, ""); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/client/TestRestClient.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/client/TestRestClient.java new file mode 100644 index 0000000000000000000000000000000000000000..7f136fad9f333b10550f36e697f83af9b934dd92 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/client/TestRestClient.java @@ -0,0 +1,148 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.client; + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.configuration.BaseTestConfiguration; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import com.huawei.boostkit.omniadvisor.spark.data.SparkRestAnalyticJob; +import org.apache.spark.SparkConf; +import org.junit.Test; +import org.mockito.Mockito; +import scala.collection.immutable.List; + +import javax.ws.rs.client.Invocation.Builder; +import javax.ws.rs.client.WebTarget; +import javax.ws.rs.core.MediaType; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +public class TestRestClient extends BaseTestConfiguration { + private static final String TEST_URL = "http://testUri"; + private static final String TEST_APP_INFO = + "{" + + "\"id\" : \"test\"," + + "\"name\" : \"test\"," + + "\"attempts\" : [{" + + "\"startTime\" : \"2023-09-08T03:10:30.194GMT\"," + + "\"endTime\" : \"2023-09-08T03:11:34.864GMT\"," + + "\"lastUpdated\" : \"2023-09-08T03:11:34.970GMT\"," + + "\"duration\" : 62670," + + "\"sparkUser\" : \"root\"," + + "\"completed\" : true," + + "\"appSparkVersion\" : \"3.1.1\"," + + "\"startTimeEpoch\" : 1694142632194," + + "\"lastUpdatedEpoch\" : 1694142694970," + + "\"endTimeEpoch\" : 1694142694864" + + "}]" + + "}"; + + private static final String TEST_APP_INFO_LIST = "[" + TEST_APP_INFO + "]"; + + private static final String TEST_EMPTY_APP_INFO_LIST = "[{\"id\":\"test\",\"name\":\"test\",\"attempts\":[]}]"; + + @Test + public void testGetApplications() throws URISyntaxException { + SparkRestClient restClient = new 
SparkRestClient("history-url", 1, new SparkConf(), "default"); + WebTarget webTarget = Mockito.mock(WebTarget.class); + when(webTarget.getUri()).thenReturn(new URI(TEST_URL)); + when(webTarget.path(any())).thenReturn(webTarget); + when(webTarget.queryParam(any(), any())).thenReturn(webTarget); + Builder builder = Mockito.mock(Builder.class); + when(builder.get(String.class)).thenReturn(TEST_APP_INFO_LIST); + when(webTarget.request(MediaType.APPLICATION_JSON)).thenReturn(builder); + restClient.setApiTarget(webTarget); + List jobList = restClient.fetchAnalyticJobs(0L, 100L); + assertEquals(jobList.size(), 1); + } + + @Test + public void testGetEmptyApplication() throws URISyntaxException { + SparkRestClient restClient = new SparkRestClient("history-url", 1, new SparkConf(), "default"); + WebTarget webTarget = Mockito.mock(WebTarget.class); + when(webTarget.getUri()).thenReturn(new URI(TEST_URL)); + when(webTarget.path(any())).thenReturn(webTarget); + when(webTarget.queryParam(any(), any())).thenReturn(webTarget); + Builder builder = Mockito.mock(Builder.class); + when(builder.get(String.class)).thenReturn(TEST_EMPTY_APP_INFO_LIST); + when(webTarget.request(MediaType.APPLICATION_JSON)).thenReturn(builder); + restClient.setApiTarget(webTarget); + List jobList = restClient.fetchAnalyticJobs(0L, 100L); + assertEquals(jobList.size(), 0); + } + + @Test + public void testAnalysis() throws IOException, URISyntaxException { + // build test file + final File workDir = new File(System.getProperty("test.dir", "target")); + + URL filePath = Thread.currentThread().getContextClassLoader() + .getResource("spark-events/application_1516285256255_0012"); + assertNotNull(filePath); + File fileToZip = new File(filePath.getPath()); + + String outZip = workDir + System.getProperty("file.separator") + "output.zip"; + File outputFile = new File(outZip); + + try (FileOutputStream fos = new FileOutputStream(outputFile); + ZipOutputStream zos = new ZipOutputStream(fos)) { + + ZipEntry zipEntry = new ZipEntry(fileToZip.getName()); + zos.putNextEntry(zipEntry); + + try (FileInputStream fis = new FileInputStream(fileToZip)) { + byte[] buffer = new byte[1024]; + int len; + while ((len = fis.read(buffer)) > 0) { + zos.write(buffer, 0, len); + } + } + } + + // test analyze + SparkRestAnalyticJob restJob = new SparkRestAnalyticJob("application_1516285256255_0012"); + assertEquals(restJob.getType(), FetcherType.SPARK); + SparkRestClient restClient = new SparkRestClient("history-url", 1, new SparkConf(), "default"); + WebTarget webTarget = Mockito.mock(WebTarget.class); + when(webTarget.getUri()).thenReturn(new URI(TEST_URL)); + when(webTarget.path(any())).thenReturn(webTarget); + when(webTarget.queryParam(any(), any())).thenReturn(webTarget); + Builder builder = Mockito.mock(Builder.class); + InputStream inputStream = new FileInputStream(outputFile.getAbsoluteFile()); + when(builder.get(InputStream.class)).thenReturn(inputStream); + when(builder.get(String.class)).thenReturn(TEST_APP_INFO); + when(webTarget.request(MediaType.APPLICATION_OCTET_STREAM)).thenReturn(builder); + when(webTarget.request(MediaType.APPLICATION_JSON)).thenReturn(builder); + restClient.setApiTarget(webTarget); + AppResult result = restClient.fetchAnalyticResult(restJob); + assertEquals(result.applicationId, "application_1516285256255_0012"); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/utils/TestSparkUtils.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/utils/TestSparkUtils.java new file mode 
100644 index 0000000000000000000000000000000000000000..66c8078f1d5a7b02e631ea8a51e19d78a08042b3 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/spark/utils/TestSparkUtils.java @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.spark.utils; + +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import org.apache.hadoop.conf.Configuration; +import org.apache.spark.SparkConf; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.ZStdCompressionCodec; +import org.junit.Test; +import scala.Option; +import scala.collection.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class TestSparkUtils { + @Test + public void testGetPropertiesFromFile() { + String filePath = Thread.currentThread().getContextClassLoader().getResource("test-spark.conf").getPath(); + Map map = SparkUtils.getPropertiesFromFile(filePath); + assertEquals(map.size(), 1); + assertEquals(map.get("spark.master").get(), "yarn"); + } + + @Test(expected = OmniAdvisorException.class) + public void testLoadLogFileFromErrorPath() { + SparkUtils.findApplicationFiles(new Configuration(), "errorPath", 0L, 100L, 500); + } + + @Test + public void getApplicationIdFromFile() { + String fileName = "app_id.ztsd"; + assertEquals(SparkUtils.getApplicationIdFromFile(fileName), "app_id"); + } + + @Test + public void testLoadCompressionCodec() { + SparkConf conf = new SparkConf(); + Option codec = SparkUtils.compressionCodecForLogName(conf, "app_id.zstd"); + assertTrue(codec.isDefined()); + assertEquals(codec.get().getClass(), ZStdCompressionCodec.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testUnKnownLoadCompressionCodec() { + SparkConf conf = new SparkConf(); + SparkUtils.compressionCodecForLogName(conf, "app_id.unknown"); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/TestTezFetcher.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/TestTezFetcher.java new file mode 100644 index 0000000000000000000000000000000000000000..8e67edc2ae054f6758fef64a4e6b08796a0f0cbb --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/TestTezFetcher.java @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez; + +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.configuration.BaseTestConfiguration; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.fetcher.FetcherType; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import com.huawei.boostkit.omniadvisor.spark.data.SparkRestAnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.data.TezAnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.utils.TestTezClient; +import org.apache.hadoop.security.authentication.client.AuthenticationException; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.junit.Test; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.util.List; +import java.util.Optional; + +import static com.huawei.boostkit.omniadvisor.tez.utils.TestJsonUtilsFactory.getAppListJsonUtils; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestJsonUtilsFactory.getFailedJsonUtils; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestJsonUtilsFactory.getKilledJsonUtils; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestJsonUtilsFactory.getSuccessJsonUtils; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestJsonUtilsFactory.getUnFinishedJsonUtils; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.FAILED_JOB; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.KILLED_JOB; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.SUCCESS; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.SUCCESS_JOB; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TIME_14; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TIME_18; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.UNFINISHED_JOB; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class TestTezFetcher extends BaseTestConfiguration { + @Test + public void testGetApplicationFromTimeline() throws IOException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTimelineClient(TestTezClient.getTestTimelineClient()); + List job = fetcher.fetchAnalyticJobs(0L, 100L); + assertEquals(job.size(), 1); + } + + @Test + public void testGetType() { + TezFetcher fetcher = new TezFetcher(testConfiguration); + assertEquals(fetcher.getType(), FetcherType.TEZ); + } + + @Test + public void tesGetApplicationsWithError() { + TezFetcher fetcher = new TezFetcher(testConfiguration); + List jobs = fetcher.fetchAnalyticJobs(0L, 100L); + assertTrue(jobs.isEmpty()); + } + + @Test + public void testAnalyzeWithError() { + TezFetcher fetcher = new TezFetcher(testConfiguration); + Optional result = fetcher.analysis(new TezAnalyticJob("id", "name", 0L, 1L, YarnApplicationState.FINISHED)); + assertFalse(result.isPresent()); + } + + @Test + public void testEnable() { + TezFetcher fetcher = new TezFetcher(testConfiguration); + assertFalse(fetcher.isEnable()); + } + + @Test + public void testAnalyze() throws IOException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTimelineClient(TestTezClient.getTestTimelineClient()); + AnalyticJob testJob = new TezAnalyticJob("test", "test", 0, 100, 
YarnApplicationState.FINISHED); + fetcher.analysis(testJob); + } + + @Test + public void testGetApplications() throws AuthenticationException, IOException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTezJsonUtils(getAppListJsonUtils()); + List tezJobs = fetcher.fetchAnalyticJobs(TIME_14, TIME_18); + assertEquals(tezJobs.size(), 4); + } + + @Test(expected = OmniAdvisorException.class) + public void testAnalyzeJobWithErrorType() { + SparkRestAnalyticJob sparkRestAnalyticJob = new SparkRestAnalyticJob("sparkRest"); + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.analysis(sparkRestAnalyticJob); + } + + @Test + public void testAnalyzeJobWithSuccessJob() throws MalformedURLException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTezJsonUtils(getSuccessJsonUtils()); + Optional successJob = fetcher.analysis(SUCCESS_JOB); + assertTrue(successJob.isPresent()); + assertEquals(successJob.get().applicationId, SUCCESS); + } + + @Test + public void testAnalyzeJobWithFailedJob() throws MalformedURLException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTezJsonUtils(getFailedJsonUtils()); + Optional failedJob = fetcher.analysis(FAILED_JOB); + assertTrue(failedJob.isPresent()); + } + + @Test + public void testAnalyzeJobWithKilledJob() throws MalformedURLException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTezJsonUtils(getKilledJsonUtils()); + Optional killedJob = fetcher.analysis(KILLED_JOB); + assertTrue(killedJob.isPresent()); + } + + @Test + public void testAnalyzeJobWithUnFinishedJob() throws MalformedURLException { + TezFetcher fetcher = new TezFetcher(testConfiguration); + fetcher.setTezJsonUtils(getUnFinishedJsonUtils()); + Optional unFinishedJob = fetcher.analysis(UNFINISHED_JOB); + assertTrue(unFinishedJob.isPresent()); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/data/TestTezData.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/data/TestTezData.java new file mode 100644 index 0000000000000000000000000000000000000000..de71cda9a5c28de3b2dab4754ea3843022aca048 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/data/TestTezData.java @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.tez.data; + +import com.huawei.boostkit.omniadvisor.spark.data.SparkRestAnalyticJob; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.apache.tez.dag.app.dag.DAGState; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class TestTezData { + @Test + public void testTezAnalyticJobEquals() { + TezAnalyticJob job1 = new TezAnalyticJob("id", "name", 0L, 1L, YarnApplicationState.RUNNING); + TezAnalyticJob job2 = new TezAnalyticJob("id", "name", 0L, 1L, YarnApplicationState.RUNNING); + TezAnalyticJob job3 = new TezAnalyticJob("no", "nn", 0L, 1L, YarnApplicationState.RUNNING); + SparkRestAnalyticJob restJob = new SparkRestAnalyticJob("id"); + + assertTrue(job1.equals(job1)); + assertTrue(job1.equals(job2)); + assertFalse(job1.equals(job3)); + assertFalse(job1.equals(restJob)); + } + + @Test + public void testTezDagIdEquals() { + TezDagIdData data1 = new TezDagIdData("id", 0L, 1L, 1L, DAGState.SUCCEEDED); + TezDagIdData data2 = new TezDagIdData("id", 0L, 1L, 1L, DAGState.SUCCEEDED); + TezDagIdData data3 = new TezDagIdData("id2", 0L, 1L, 1L, DAGState.SUCCEEDED); + TezAnalyticJob job = new TezAnalyticJob("id", "name", 0L, 1L, YarnApplicationState.RUNNING); + + assertEquals(data1, data1); + assertEquals(data1, data2); + assertFalse(data1.equals(data3)); + assertFalse(data1.equals(job)); + } + + @Test + public void testTezDatIdCompare() { + TezDagIdData data1 = new TezDagIdData("id1", 0L, 1L, 1L, DAGState.SUCCEEDED); + TezDagIdData data2 = new TezDagIdData("id2", 0L, 2L, 2L, DAGState.SUCCEEDED); + assertEquals(0, data1.compareTo(data2)); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtils.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..95aab125c8f418c91681bba03a41934f32ccd735 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtils.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.sun.jersey.api.client.ClientHandlerException; +import org.apache.hadoop.security.authentication.client.AuthenticationException; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.MalformedURLException; + +public class TestJsonUtils { + private static TezJsonUtils testJsonUtils; + + @BeforeClass + public static void setUpClass() { + testJsonUtils = new TezJsonUtils(new TezUrlFactory("http://localhost:9999"), false, 10); + } + + @Test(expected = ConnectException.class) + public void testVerifyTimeLineServer() throws IOException { + testJsonUtils.verifyTimeLineServer(); + } + + @Test(expected = ClientHandlerException.class) + public void testGetApplicationJobs() throws AuthenticationException, IOException { + testJsonUtils.getApplicationJobs(0L, 1000L); + } + + @Test(expected = ClientHandlerException.class) + public void testGetDAGIDs() throws MalformedURLException { + testJsonUtils.getDAGIds("appId"); + } + + @Test(expected = ClientHandlerException.class) + public void testGetConfigure() throws MalformedURLException { + testJsonUtils.getConfigure("appId"); + } + + @Test(expected = ClientHandlerException.class) + public void testGetQueryString() throws MalformedURLException { + testJsonUtils.getQueryString("appId"); + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtilsFactory.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtilsFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..d1d046503db130e5e1cb2fe14332d47750713bf6 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestJsonUtilsFactory.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.google.common.collect.ImmutableList; +import org.apache.hadoop.security.authentication.client.AuthenticationException; +import org.mockito.Mockito; + +import java.io.IOException; +import java.net.MalformedURLException; + +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.FAILED; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.FAILED_DAG; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.KILLED; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.KILLED_DAG; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.SUCCESS; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.SUCCESS_DAG; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TEST_APP_LIST; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TEST_TEZ_CONFIGURE; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TEST_TEZ_QUERY; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TIME_14; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.TIME_18; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.UNFINISHED; +import static com.huawei.boostkit.omniadvisor.tez.utils.TestTezContext.UNFINISHED_DAG; + +public class TestJsonUtilsFactory { + public static TezJsonUtils getAppListJsonUtils() throws AuthenticationException, IOException { + TezJsonUtils tezJsonUtils = Mockito.mock(TezJsonUtils.class); + Mockito.when(tezJsonUtils.getApplicationJobs(TIME_14, TIME_18)).thenReturn(TEST_APP_LIST); + return tezJsonUtils; + } + + public static TezJsonUtils getSuccessJsonUtils() throws MalformedURLException { + TezJsonUtils successJsonUtils = Mockito.mock(TezJsonUtils.class); + Mockito.when(successJsonUtils.getDAGIds(SUCCESS)).thenReturn(ImmutableList.of(SUCCESS_DAG)); + Mockito.when(successJsonUtils.getConfigure(SUCCESS)).thenReturn(TEST_TEZ_CONFIGURE); + Mockito.when(successJsonUtils.getQueryString(SUCCESS)).thenReturn(TEST_TEZ_QUERY); + return successJsonUtils; + } + + public static TezJsonUtils getFailedJsonUtils() throws MalformedURLException { + TezJsonUtils failedJsonUtils = Mockito.mock(TezJsonUtils.class); + Mockito.when(failedJsonUtils.getDAGIds(FAILED)).thenReturn(ImmutableList.of(FAILED_DAG)); + Mockito.when(failedJsonUtils.getConfigure(FAILED)).thenReturn(TEST_TEZ_CONFIGURE); + Mockito.when(failedJsonUtils.getQueryString(FAILED)).thenReturn(TEST_TEZ_QUERY); + return failedJsonUtils; + } + + public static TezJsonUtils getKilledJsonUtils() throws MalformedURLException { + TezJsonUtils killedJsonUtils = Mockito.mock(TezJsonUtils.class); + Mockito.when(killedJsonUtils.getDAGIds(KILLED)).thenReturn(ImmutableList.of(KILLED_DAG)); + Mockito.when(killedJsonUtils.getConfigure(KILLED)).thenReturn(TEST_TEZ_CONFIGURE); + Mockito.when(killedJsonUtils.getQueryString(KILLED)).thenReturn(TEST_TEZ_QUERY); + return killedJsonUtils; + } + + public static TezJsonUtils getUnFinishedJsonUtils() throws MalformedURLException { + TezJsonUtils unFinishedJsonUtils = Mockito.mock(TezJsonUtils.class); + Mockito.when(unFinishedJsonUtils.getDAGIds(UNFINISHED)).thenReturn(ImmutableList.of(UNFINISHED_DAG)); + Mockito.when(unFinishedJsonUtils.getConfigure(UNFINISHED)).thenReturn(TEST_TEZ_CONFIGURE); + Mockito.when(unFinishedJsonUtils.getQueryString(UNFINISHED)).thenReturn(TEST_TEZ_QUERY); + return unFinishedJsonUtils; + } +} diff --git 
a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezClient.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezClient.java new file mode 100644 index 0000000000000000000000000000000000000000..3be61a4cc05d71115cd76481ce501e8d0b627645 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezClient.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import org.codehaus.jackson.map.ObjectMapper; +import org.mockito.Mockito; + +import java.io.IOException; +import java.net.URL; + +import static org.mockito.Mockito.when; + +public class TestTezClient { + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static final String TEST_APP_STRING = + "{" + + "\"app\":[" + + "{" + + "\"appId\":\"application_test\"," + + "\"name\":\"sql_test\"," + + "\"appState\":\"FINISHED\"" + + "}" + + "]" + + "}"; + + private static final String TEST_APP_INFO = + "{" + + "\"entitytype\":\"TEZ_APPLICATION\"," + + "\"otherinfo\":{" + + "\"config\":{" + + "\"tez.am.resource.memory.mb\":1024," + + "\"tez.am.resource.cpu.vcores\":5," + + "\"tez.task.resource.memory.mb\":1024," + + "\"tez.task.reource.cpu.vcores\":5" + + "}" + + "}" + + "}"; + + private static final String TEST_DAG_INFO = + "{" + + "\"entities\":[" + + "{" + + "\"entitytype\":\"TEZ_DAG_ID\"," + + "\"entity\":\"dag_test_1\"," + + "\"otherinfo\":{" + + "\"startTime\":0," + + "\"timeTaken\":100," + + "\"endTime\":100," + + "\"status\":\"SUCCEEDED\"" + + "}" + + "}" + + "]" + + "}"; + + private static final String TEST_DAG_EXTRA_INFO = + "{" + + "\"entitytype\":\"TEZ_DAG_EXTRA_INFO\"," + + "\"otherinfo\":{" + + "\"dagPlan\":{" + + "\"dagContext\":{" + + "\"description\":\"select * from table\"" + + "}" + + "}" + + "}" + + "}"; + + public static TimelineClient getTestTimelineClient() throws IOException { + TimelineClient testClient = Mockito.mock(TimelineClient.class); + when(testClient.readJsonNode(new URL("http://testUrl:8188/ws/v1/applicationhistory/apps?applicationTypes=TEZ&startedTimeBegin=0&startedTimeEnd=100"))) + .thenReturn(MAPPER.readTree(TEST_APP_STRING)); + when(testClient.readJsonNode(new URL("http://testUrl:8188/ws/v1/timeline/TEZ_APPLICATION/tez_test"))) + .thenReturn(MAPPER.readTree(TEST_APP_INFO)); + when(testClient.readJsonNode(new URL("http://testUrl:8188/ws/v1/timeline/TEZ_DAG_ID?primaryFilter=applicationId:test"))) + .thenReturn(MAPPER.readTree(TEST_DAG_INFO)); + when((testClient.readJsonNode(new URL("http://testUrl:8188/ws/v1/timeline/TEZ_DAG_EXTRA_INFO/dag_test_1")))) + .thenReturn(MAPPER.readTree(TEST_DAG_EXTRA_INFO)); + return testClient; + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezContext.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezContext.java new file mode 100644 
index 0000000000000000000000000000000000000000..25526b18fff384a699fe25f9606bedfee38c553f --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTezContext.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omniadvisor.analysis.AnalyticJob; +import com.huawei.boostkit.omniadvisor.exception.OmniAdvisorException; +import com.huawei.boostkit.omniadvisor.tez.data.TezAnalyticJob; +import com.huawei.boostkit.omniadvisor.tez.data.TezDagIdData; +import com.huawei.boostkit.omniadvisor.utils.MathUtils; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.apache.tez.dag.app.dag.DAGState; + +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.List; +import java.util.Map; + +public class TestTezContext { + public static final SimpleDateFormat DF = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + + public static final String DATE_14 = "2023-09-02 14:00:00"; + public static final String DATE_15 = "2023-09-02 15:00:00"; + public static final String DATE_16 = "2023-09-02 16:00:00"; + public static final String DATE_17 = "2023-09-02 17:00:00"; + public static final String DATE_18 = "2023-09-02 18:00:00"; + + public static final long TIME_14; + public static final long TIME_15; + public static final long TIME_16; + public static final long TIME_17; + public static final long TIME_18; + + static { + try { + TIME_14 = DF.parse(DATE_14).getTime(); + TIME_15 = DF.parse(DATE_15).getTime(); + TIME_16 = DF.parse(DATE_16).getTime(); + TIME_17 = DF.parse(DATE_17).getTime(); + TIME_18 = DF.parse(DATE_18).getTime(); + } catch (ParseException e) { + throw new OmniAdvisorException("Parse time failed", e); + } + } + + public static final String SUCCESS = "success"; + public static final String FAILED = "failed"; + public static final String KILLED = "killed"; + public static final String UNFINISHED = "UNFINISHED"; + + public static final AnalyticJob SUCCESS_JOB = + new TezAnalyticJob(SUCCESS, SUCCESS, TIME_14, TIME_15, YarnApplicationState.FINISHED); + public static final AnalyticJob FAILED_JOB = + new TezAnalyticJob(FAILED, FAILED, TIME_15, TIME_16, YarnApplicationState.FINISHED); + public static final AnalyticJob KILLED_JOB = + new TezAnalyticJob(KILLED, KILLED, TIME_16, TIME_17, YarnApplicationState.KILLED); + public static final AnalyticJob UNFINISHED_JOB = + new TezAnalyticJob(UNFINISHED, UNFINISHED, TIME_17, TIME_18, YarnApplicationState.RUNNING); + + public static final TezDagIdData SUCCESS_DAG = + new TezDagIdData(SUCCESS, TIME_14, TIME_15, MathUtils.HOUR_IN_MS, DAGState.SUCCEEDED); + public static final TezDagIdData FAILED_DAG = + new TezDagIdData(FAILED, TIME_15, TIME_16, MathUtils.HOUR_IN_MS, DAGState.FAILED); + public static 
final TezDagIdData KILLED_DAG = + new TezDagIdData(KILLED, TIME_16, TIME_17, MathUtils.HOUR_IN_MS, DAGState.RUNNING); + public static final TezDagIdData UNFINISHED_DAG = + new TezDagIdData(UNFINISHED, TIME_17, TIME_18, MathUtils.HOUR_IN_MS, DAGState.RUNNING); + + public static final List<AnalyticJob> TEST_APP_LIST = + ImmutableList.of(SUCCESS_JOB, FAILED_JOB, KILLED_JOB, UNFINISHED_JOB); + + public static final String TEST_TEZ_QUERY = "select id, name from table"; + + public static final Map<String, String> TEST_TEZ_CONFIGURE = ImmutableBiMap.of( + "tez.am.resource.memory.mb", "200", "tez.am.resource.cpu.vcores", "2", + "tez.task.resource.memory.mb", "300", "tez.task.resource.cpu.vcores", "4"); +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTimelineClient.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTimelineClient.java new file mode 100644 index 0000000000000000000000000000000000000000..f6624f38a0064fc53018e03b5df8c7df728142b8 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestTimelineClient.java @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import com.sun.jersey.api.client.Client; +import com.sun.jersey.api.client.ClientResponse; +import com.sun.jersey.api.client.WebResource; +import org.apache.hadoop.conf.Configuration; +import org.codehaus.jackson.JsonNode; +import org.codehaus.jettison.json.JSONException; +import org.codehaus.jettison.json.JSONObject; + +import org.junit.Test; +import org.mockito.Mockito; + +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; + +import java.net.MalformedURLException; +import java.net.URL; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; + +public class TestTimelineClient { + @Test + public void testReadJsonNode() throws MalformedURLException, JSONException { + try (TimelineClient timelineClient = new TimelineClient(new Configuration(), false, 6000)) { + String testUrl = "http://test-url:8188/test"; + + ClientResponse response = Mockito.mock(ClientResponse.class); + when(response.getStatus()).thenReturn(Response.Status.OK.getStatusCode()); + JSONObject jsonObject = new JSONObject("{\"name\" : \"test\"}"); + when(response.getEntity(JSONObject.class)).thenReturn(jsonObject); + WebResource resource = Mockito.mock(WebResource.class); + WebResource.Builder builder = Mockito.mock(WebResource.Builder.class); + when(resource.accept(MediaType.APPLICATION_JSON_TYPE)).thenReturn(builder); + when(builder.type(MediaType.APPLICATION_JSON_TYPE)).thenReturn(builder); + when(builder.get(ClientResponse.class)).thenReturn(response); + + Client httpClient = Mockito.mock(Client.class); + when(httpClient.resource(testUrl)).thenReturn(resource); + timelineClient.setClient(httpClient); + JsonNode object = timelineClient.readJsonNode(new URL(testUrl)); + 
assertEquals(object.get("name").getTextValue(), "test"); + } + } +} diff --git a/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestUrlFactory.java b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestUrlFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..3f5be00489f7c48eaedcd1f5eb2fbf4def04e849 --- /dev/null +++ b/omniadvisor/src/test/java/com/huawei/boostkit/omniadvisor/tez/utils/TestUrlFactory.java @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omniadvisor.tez.utils; + +import org.junit.BeforeClass; +import org.junit.Test; + +import java.net.MalformedURLException; + +import static org.junit.Assert.assertEquals; + +public class TestUrlFactory { + private static final String BASE_URL = "http://localhost:8088"; + private static TezUrlFactory urlFactory; + + @BeforeClass + public static void setUpClass() { + urlFactory = new TezUrlFactory(BASE_URL); + } + + @Test + public void testGetRootURL() throws MalformedURLException { + assertEquals(urlFactory.getRootURL().toString(), "http://localhost:8088/ws/v1/timeline"); + } + + @Test + public void testGetApplicationURL() throws MalformedURLException { + assertEquals(urlFactory.getApplicationURL("appId").toString(), + "http://localhost:8088/ws/v1/timeline/TEZ_APPLICATION/tez_appId"); + } + + @Test + public void testGetDagIdURL() throws MalformedURLException { + assertEquals(urlFactory.getDagIdURL("appId").toString(), + "http://localhost:8088/ws/v1/timeline/TEZ_DAG_ID?primaryFilter=applicationId:appId"); + } + + @Test + public void testGetDagExtraInfoURL() throws MalformedURLException { + assertEquals(urlFactory.getDagExtraInfoURL("dagId").toString(), + "http://localhost:8088/ws/v1/timeline/TEZ_DAG_EXTRA_INFO/dagId"); + } + + @Test + public void testGetApplicationHistoryURL() throws MalformedURLException { + assertEquals(urlFactory.getApplicationHistoryURL(0L, 1L).toString(), + "http://localhost:8088/ws/v1/applicationhistory/apps?applicationTypes=TEZ&startedTimeBegin=0&startedTimeEnd=1"); + } +} diff --git a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java new file mode 100644 index 0000000000000000000000000000000000000000..b60bf8455022f2cd3d6700cef019ad02867b3aea --- /dev/null +++ b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java @@ -0,0 +1,113 @@ +/* + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import com.google.common.collect.ImmutableList; +import com.huawei.boostkit.omniadvisor.models.AppResult; +import org.apache.spark.status.api.v1.ApplicationAttemptInfo; +import org.apache.spark.status.api.v1.ApplicationEnvironmentInfo; +import org.apache.spark.status.api.v1.ApplicationInfo; +import org.apache.spark.status.api.v1.JobData; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.Mockito; +import scala.Option; +import scala.Tuple2; +import scala.collection.immutable.HashMap; + +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; +import static scala.collection.JavaConverters.asScalaBuffer; + +public class TestSparkApplicationDataExtractor { + private static final String TEST_WORK_LOAD = "default"; + + private static ApplicationEnvironmentInfo environmentInfo; + private static ApplicationAttemptInfo completeAttemptInfo; + private static ApplicationAttemptInfo unCompleteAttemptInfo; + + private static JobData jobData; + private static JobData failedData; + private static JobData runningData; + + @BeforeClass + public static void setUp() throws ParseException { + List<Tuple2<String, String>> configs = ImmutableList.of( + new Tuple2<>("spark.executor.memory", "1g"), + new Tuple2<>("spark.executor.cores", "1"), + new Tuple2<>("spark.executor.instances", "1")); + environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); + when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(configs)); + + SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss"); + Date startDate = format.parse("2020-05-01 18:00:00"); + Date endDate = format.parse("2020-05-01 18:00:01"); + + completeAttemptInfo = new ApplicationAttemptInfo(Option.apply("attemptId"), startDate, endDate, endDate, 1000L, "user", true, "3.1.1"); + unCompleteAttemptInfo = new ApplicationAttemptInfo(Option.apply("attemptId"), startDate, endDate, endDate, 1000L, "user", false, "3.1.1"); + + jobData = new JobData(1, "jobName", Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of()), Option.empty(), JobExecutionStatus.SUCCEEDED, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, new HashMap<>()); + failedData = new JobData(1, "jobName", Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of()), Option.empty(), JobExecutionStatus.FAILED, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, new HashMap<>()); + runningData = new JobData(1, "jobName", Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of()), Option.empty(), JobExecutionStatus.RUNNING, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, new HashMap<>()); + } + + @Test + public void testExtractData() { + ApplicationInfo applicationInfo = new ApplicationInfo("id", "name", Option.empty(), Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of(completeAttemptInfo))); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, 
asScalaBuffer(ImmutableList.of(jobData))); + assertEquals(result.applicationId, "id"); + assertEquals(result.executionStatus, AppResult.SUCCEEDED_STATUS); + assertEquals(result.durationTime, 1000L); + } + + @Test + public void testExtractDataWithUnCompleteApplication() { + ApplicationInfo applicationInfo = new ApplicationInfo("id", "name", Option.empty(), Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of(unCompleteAttemptInfo))); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData))); + assertEquals(result.applicationId, "id"); + assertEquals(result.executionStatus, AppResult.FAILED_STATUS); + assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); + } + + @Test + public void testExtractDataWithFailedApplication() { + ApplicationInfo applicationInfo = new ApplicationInfo("id", "name", Option.empty(), Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of(completeAttemptInfo))); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData))); + assertEquals(result.applicationId, "id"); + assertEquals(result.executionStatus, AppResult.FAILED_STATUS); + assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); + } + + @Test + public void testExtractDataWithEmptyJob() { + ApplicationInfo applicationInfo = new ApplicationInfo("id", "name", Option.empty(), Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of(completeAttemptInfo))); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of())); + assertEquals(result.applicationId, "id"); + assertEquals(result.executionStatus, AppResult.FAILED_STATUS); + assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); + } + + @Test(expected = IllegalArgumentException.class) + public void testExtractDataWithEmptyApplication() { + ApplicationInfo applicationInfo = new ApplicationInfo("id", "name", Option.empty(), Option.empty(), Option.empty(), Option.empty(), asScalaBuffer(ImmutableList.of())); + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of())); + } +} diff --git a/omniadvisor/src/test/resources/SparkParams b/omniadvisor/src/test/resources/SparkParams new file mode 100644 index 0000000000000000000000000000000000000000..f90561fd88b1e9e81f07d18483e1801935961e38 --- /dev/null +++ b/omniadvisor/src/test/resources/SparkParams @@ -0,0 +1,3 @@ +spark.executor.memory +spark.executor.cores +spark.executor.instances diff --git a/omniadvisor/src/test/resources/TezParams b/omniadvisor/src/test/resources/TezParams new file mode 100644 index 0000000000000000000000000000000000000000..7a42fcc60d453a92972da6b325c6c39cef2d8125 --- /dev/null +++ b/omniadvisor/src/test/resources/TezParams @@ -0,0 +1,4 @@ +tez.am.resource.memory.mb +tez.am.resource.cpu.vcores +tez.task.resource.memory.mb +tez.task.resource.cpu.vcores \ No newline at end of file diff --git a/omniadvisor/src/test/resources/omniAdvisorLogAnalyzer.properties b/omniadvisor/src/test/resources/omniAdvisorLogAnalyzer.properties new file mode 100644 index 0000000000000000000000000000000000000000..575cc4581c16db075263cd1cd8e1871626474c3b --- /dev/null +++ 
b/omniadvisor/src/test/resources/omniAdvisorLogAnalyzer.properties @@ -0,0 +1,11 @@ +log.analyzer.thread.count=3 + +datasource.db.driver=com.mysql.cj.jdbc.Driver +datasource.db.url=url + +spark.enable=false + +tez.enable=true +tez.workload=workload +tez.timeline.url=http://testUrl:8188 +tez.timeline.timeout.ms=6000 diff --git a/omniadvisor/src/test/resources/spark-events/application_1516285256255_0012 b/omniadvisor/src/test/resources/spark-events/application_1516285256255_0012 new file mode 100644 index 0000000000000000000000000000000000000000..3e1736c3fe22494c5a4054e2473473cd2301a779 --- /dev/null +++ b/omniadvisor/src/test/resources/spark-events/application_1516285256255_0012 @@ -0,0 +1,71 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","Java Version":"1.8.0_161 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"apiros-1.gce.test.com","spark.eventLog.enabled":"true","spark.driver.port":"33058","spark.repl.class.uri":"spark://apiros-1.gce.test.com:33058/classes","spark.jars":"","spark.repl.class.outputDir":"/tmp/spark-6781fb17-e07a-4b32-848b-9936c2e88b33/repl-c0fd7008-04be-471e-a173-6ad3e62d53d7","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"1","spark.scheduler.mode":"FIFO","spark.executor.instances":"8","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"yarn","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/github/spark","spark.sql.catalogImplementation":"hive","spark.driver.appUIAddress":"http://apiros-1.gce.test.com:4040","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"apiros-1.gce.test.com","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://apiros-1.gce.test.com:8088/proxy/application_1516285256255_0012","spark.app.id":"application_1516285256255_0012"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.161-b14","java.endorsed.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/endorsed","java.runtime.version":"1.8.0_161-b14","java.vm.info":"mixed mode","java.ext.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"OpenJDK Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API 
Specification","sun.boot.class.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/resources.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/rt.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jsse.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jce.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/charsets.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jfr.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"OpenJDK 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.blacklist.stage.maxFailedExecutorsPerNode=1 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell --executor-memory 2G --num-executors 8 spark-shell","java.home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","java.version":"1.8.0_161","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-hive_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/github/spark/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libthrift-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stringtemplate-3.2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/JavaEWAH-0.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javolution-5.5.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libfb303-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jodd-core-3.5.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-0.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ST4-4.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-core-3.2.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-servlet-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/etc/hadoop/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-yarn_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-client-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-runtime-3.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jta-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/derby-10.12.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-logging-1.1.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-pool-1.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-web-proxy-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-2.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jdo-api-3.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-dbcp-1.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System 
Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1516285256255_0012","Timestamp":1516300235119,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252095,"Executor ID":"2","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"apiros-3.gce.test.com","Port":38670},"Maximum Memory":956615884,"Timestamp":1516300252260,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252715,"Executor ID":"3","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252918,"Executor ID":"1","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"apiros-2.gce.test.com","Port":38641},"Maximum Memory":956615884,"Timestamp":1516300252959,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"apiros-3.gce.test.com","Port":34970},"Maximum Memory":956615884,"Timestamp":1516300252988,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300253542,"Executor ID":"4","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"apiros-2.gce.test.com","Port":33229},"Maximum Memory":956615884,"Timestamp":1516300253653,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300254323,"Executor ID":"5","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"apiros-2.gce.test.com","Port":45147},"Maximum 
Memory":956615884,"Timestamp":1516300254385,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1516300392631,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use 
Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage 
ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394348,"executorId":"5","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklistedForStage","time":1516300394348,"hostId":"apiros-2.gce.test.com","executorFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394356,"executorId":"4","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1332","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"33","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3075188","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394338,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3075188,"Value":3075188,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1332,"Value":1332,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1332,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3075188,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File 
Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1184","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"82","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"16858066","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394355,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":16858066,"Value":19933254,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":82,"Value":115,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1184,"Value":2516,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1184,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":82,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":16858066,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring 
Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"51","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"183718","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394390,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":183718,"Value":20116972,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":51,"Value":2567,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":51,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":183718,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring 
Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"27","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"191901","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394393,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":191901,"Value":20308873,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":27,"Value":2594,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":27,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes 
Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":191901,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394606,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3322956,"Value":23631829,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":1080,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":78,"Value":193,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":278399617,"Value":278399617,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":493,"Value":3087,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":263386625,"Value":263386625,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1206,"Value":1206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1206,"Executor Deserialize CPU Time":263386625,"Executor Run Time":493,"Executor CPU Time":278399617,"Result Size":1134,"JVM GC Time":78,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3322956,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage 
Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394860,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3587839,"Value":27219668,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":291,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":2160,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":102,"Value":295,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":349920830,"Value":628320447,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":681,"Value":3768,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":365807898,"Value":629194523,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1282,"Value":2488,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1282,"Executor Deserialize CPU Time":365807898,"Executor Run Time":681,"Executor CPU Time":349920830,"Result Size":1134,"JVM GC Time":102,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":3587839,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394880,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3662221,"Value":30881889,"Internal":true,"Count Failed 
Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":435,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":3240,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":75,"Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":368865439,"Value":997185886,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":914,"Value":4682,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":353981050,"Value":983175573,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1081,"Value":3569,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1081,"Executor Deserialize CPU Time":353981050,"Executor Run Time":914,"Executor CPU Time":368865439,"Result Size":1134,"JVM GC Time":75,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3662221,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394974,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":377601,"Value":31259490,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":582,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":4320,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count 
Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":4450,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":28283110,"Value":1025468996,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":84,"Value":4766,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":10894331,"Value":994069904,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":11,"Value":3580,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":11,"Executor Deserialize CPU Time":10894331,"Executor Run Time":84,"Executor CPU Time":28283110,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":377601,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395069,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":366050,"Value":31625540,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":15,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":729,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":5400,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":5541,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25678331,"Value":1051147327,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":48,"Value":4814,"Internal":true,"Count Failed 
Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4793905,"Value":998863809,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3585,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4793905,"Executor Run Time":48,"Executor CPU Time":25678331,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":366050,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395073,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":311940,"Value":31937480,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":876,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":6480,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":6589,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":27304550,"Value":1078451877,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":54,"Value":4868,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12246145,"Value":1011109954,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":3641,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":12246145,"Executor Run Time":54,"Executor CPU Time":27304550,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write 
Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":311940,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395165,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":323898,"Value":32261378,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1023,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":7560,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":7637,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21689428,"Value":1100141305,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":77,"Value":4945,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4239884,"Value":1015349838,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3645,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4239884,"Executor Run Time":77,"Executor CPU Time":21689428,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":323898,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task 
ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395201,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":301705,"Value":32563083,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":24,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":1167,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":8640,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":8728,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20826337,"Value":1120967642,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":76,"Value":5021,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4598966,"Value":1019948804,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3650,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4598966,"Executor Run Time":76,"Executor CPU Time":20826337,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":301705,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395225,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":319101,"Value":32882184,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1314,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":9720,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":9776,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21657558,"Value":1142625200,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":34,"Value":5055,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4010338,"Value":1023959142,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3654,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4010338,"Executor Run Time":34,"Executor CPU Time":21657558,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":319101,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395276,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":369513,"Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1461,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":10800,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":10824,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20585619,"Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":25,"Value":5080,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5860574,"Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":25,"Value":3679,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":25,"Executor Deserialize CPU Time":5860574,"Executor Run Time":25,"Executor CPU Time":20585619,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks 
Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":369513,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Completion Time":1516300395279,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":5080,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10824,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":30,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":10800,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1461,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":5,"Internal":true,"Count Failed 
Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":3679,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result 
Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395525,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52455999,"Value":52455999,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":95,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":23136577,"Value":23136577,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":82,"Value":82,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":82,"Executor Deserialize CPU Time":23136577,"Executor Run Time":95,"Executor CPU Time":52455999,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task 
ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395576,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":13617615,"Value":66073614,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":29,"Value":124,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3469612,"Value":26606189,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":86,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3469612,"Executor Run Time":29,"Executor CPU Time":13617615,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1516300395581,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":55540208,"Value":121613822,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":179,"Value":303,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22400065,"Value":49006254,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":78,"Value":164,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":78,"Executor Deserialize CPU Time":22400065,"Executor Run Time":179,"Executor CPU Time":55540208,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395593,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52311573,"Value":173925395,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":153,"Value":456,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20519033,"Value":69525287,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":67,"Value":231,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":67,"Executor Deserialize CPU Time":20519033,"Executor Run Time":153,"Executor CPU Time":52311573,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395660,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":5670,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11294260,"Value":185219655,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":489,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3570887,"Value":73096174,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":235,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3570887,"Executor Run Time":33,"Executor CPU Time":11294260,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395669,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed 
Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":6804,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":12983732,"Value":198203387,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":44,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3518757,"Value":76614931,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":239,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3518757,"Executor Run Time":44,"Executor CPU Time":12983732,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395674,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7938,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":14706240,"Value":212909627,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":64,"Value":597,"Internal":true,"Count Failed 
Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7698059,"Value":84312990,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":21,"Value":260,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":21,"Executor Deserialize CPU Time":7698059,"Executor Run Time":64,"Executor CPU Time":14706240,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":52,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":195,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":292,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":9224,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91696783,"Value":304606410,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":221,"Value":818,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24063461,"Value":108376451,"Internal":true,"Count Failed 
Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":150,"Value":410,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":150,"Executor Deserialize CPU Time":24063461,"Executor Run Time":221,"Executor CPU Time":91696783,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":52,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":20,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":107,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":244,"Value":439,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":243,"Value":535,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":5,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":5,"Value":11,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":10510,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91683507,"Value":396289917,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":289,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22106726,"Value":130483177,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":79,"Value":489,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":79,"Executor Deserialize CPU Time":22106726,"Executor Run Time":289,"Executor CPU Time":91683507,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":5,"Local Blocks Fetched":5,"Fetch Wait Time":107,"Remote Bytes Read":243,"Remote Bytes Read To 
Disk":0,"Local Bytes Read":244,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395728,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":634,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":827,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":13,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":17,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":17607810,"Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":1140,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2897647,"Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":491,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2897647,"Executor Run Time":33,"Executor CPU Time":17607810,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at 
:30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Completion Time":1516300395728,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":159,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":827,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":634,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":491,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":13,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":1140,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":17,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":30,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1516300395734,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1516300707938} diff --git a/omniadvisor/src/test/resources/test-spark.conf b/omniadvisor/src/test/resources/test-spark.conf new file mode 100644 index 0000000000000000000000000000000000000000..6cbe2baeb6e87dde158fa3a7d5fab538d4c8c1dc --- /dev/null +++ b/omniadvisor/src/test/resources/test-spark.conf @@ -0,0 +1 @@ +spark.master yarn \ No newline at end of file diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala 
b/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala deleted file mode 100644 index a3ab16e76d12e1d5a6903a238ae029be8b523486..0000000000000000000000000000000000000000 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark.util - -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig._ -import java.util.concurrent.ConcurrentHashMap -import scala.collection.mutable - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} -import org.apache.spark.sql.catalyst.optimizer.rules.RewriteTime -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, SubqueryAlias} - -object ViewMetadata extends RewriteHelper { - - val viewToViewQueryPlan = new ConcurrentHashMap[String, LogicalPlan]() - - val viewToTablePlan = new ConcurrentHashMap[String, LogicalPlan]() - - val viewToContainsTables = new ConcurrentHashMap[String, Set[TableEqual]]() - - val tableToViews = new ConcurrentHashMap[String, mutable.Set[String]]() - - var spark: SparkSession = _ - - val STATUS_UN_LOAD = "UN_LOAD" - val STATUS_LOADING = "LOADING" - val STATUS_LOADED = "LOADED" - - var status: String = STATUS_UN_LOAD - - def setSpark(sparkSession: SparkSession): Unit = { - spark = sparkSession - status = STATUS_LOADING - } - - def saveViewMetadataToMap(catalogTable: CatalogTable): Unit = this.synchronized { - // if QUERY_REWRITE_ENABLED is false, doesn't load ViewMetadata - if (!catalogTable.properties.getOrElse(MV_REWRITE_ENABLED, "false").toBoolean) { - return - } - - val viewQuerySql = catalogTable.properties.getOrElse(MV_QUERY_ORIGINAL_SQL, "") - if (viewQuerySql.isEmpty) { - logError(s"mvTable: ${catalogTable.identifier.quotedString}'s viewQuerySql is empty!") - return - } - - // preserve preDatabase and set curDatabase - val preDatabase = spark.catalog.currentDatabase - val curDatabase = catalogTable.properties.getOrElse(MV_QUERY_ORIGINAL_SQL_CUR_DB, "") - if (curDatabase.isEmpty) { - logError(s"mvTable: ${catalogTable.identifier.quotedString}'s curDatabase is empty!") - return - } - try { - spark.sessionState.catalogManager.setCurrentNamespace(Array(curDatabase)) - - // db.table - val tableName = catalogTable.identifier.quotedString - val viewTablePlan = RewriteTime - .withTimeStat("viewTablePlan") { - spark.table(tableName).queryExecution.analyzed match { - case SubqueryAlias(_, child) => child - case a@_ => a - } - } - var viewQueryPlan = RewriteTime - .withTimeStat("viewQueryPlan") { - 
spark.sql(viewQuerySql).queryExecution.analyzed - } - viewQueryPlan = viewQueryPlan match { - case RepartitionByExpression(_, child, _) => - child - case _ => - viewQueryPlan - } - // reset preDatabase - spark.sessionState.catalogManager.setCurrentNamespace(Array(preDatabase)) - - // spark_catalog.db.table - val viewName = catalogTable.identifier.toString() - - // mappedViewQueryPlan and mappedViewContainsTables - val (mappedViewQueryPlan, mappedViewContainsTables) = RewriteTime - .withTimeStat("extractTables") { - extractTables(viewQueryPlan) - } - - mappedViewContainsTables - .foreach { mappedViewContainsTable => - val name = mappedViewContainsTable.tableName - val views = tableToViews.getOrDefault(name, mutable.Set.empty) - views += viewName - tableToViews.put(name, views) - } - - // extract view query project's Attr and replace view table's Attr by query project's Attr - // match function is attributeReferenceEqualSimple, by name and data type - // Attr of table cannot used, because same Attr in view query and view table, - // it's table is different. - val mappedViewTablePlan = RewriteTime - .withTimeStat("mapTablePlanAttrToQuery") { - mapTablePlanAttrToQuery(viewTablePlan, mappedViewQueryPlan) - } - - viewToContainsTables.put(viewName, mappedViewContainsTables) - viewToViewQueryPlan.putIfAbsent(viewName, mappedViewQueryPlan) - viewToTablePlan.putIfAbsent(viewName, mappedViewTablePlan) - } catch { - case e: Throwable => - logDebug(s"Failed to saveViewMetadataToMap,errmsg: ${e.getMessage}") - // reset preDatabase - spark.sessionState.catalogManager.setCurrentNamespace(Array(preDatabase)) - } - } - - def isEmpty: Boolean = { - viewToTablePlan.isEmpty - } - - def isViewExists(viewIdentifier: String): Boolean = { - viewToTablePlan.containsKey(viewIdentifier) - } - - def addCatalogTableToCache(table: CatalogTable): Unit = this.synchronized { - saveViewMetadataToMap(table) - } - - def removeMVCache(tableName: TableIdentifier): Unit = this.synchronized { - val viewName = tableName.toString() - viewToContainsTables.remove(viewName) - viewToViewQueryPlan.remove(viewName) - viewToTablePlan.remove(viewName) - tableToViews.forEach { (key, value) => - if (value.contains(viewName)) { - value -= viewName - tableToViews.put(key, value) - } - } - } - - def init(sparkSession: SparkSession): Unit = { - if (status == STATUS_LOADED) { - return - } - - setSpark(sparkSession) - forceLoad() - status = STATUS_LOADED - } - - def forceLoad(): Unit = this.synchronized { - val catalog = spark.sessionState.catalog - - // load from all db - for (db <- catalog.listDatabases()) { - val tables = RewriteTime.withTimeStat("loadTable") { - omniCacheFilter(catalog, db) - } - RewriteTime.withTimeStat("saveViewMetadataToMap") { - tables.foreach(tableData => saveViewMetadataToMap(tableData)) - } - } - } - - def omniCacheFilter(catalog: SessionCatalog, - mvDataBase: String): Seq[CatalogTable] = { - try { - val allTables = catalog.listTables(mvDataBase) - catalog.getTablesByName(allTables).filter { tableData => - tableData.properties.contains(MV_QUERY_ORIGINAL_SQL) - } - } catch { - // if db exists a table hive materialized view, will throw analysis exception - case e: Throwable => - logDebug(s"Failed to listTables in $mvDataBase, errmsg: ${e.getMessage}") - Seq.empty[CatalogTable] - } - } -} diff --git a/omnidata/omnidata-hive-connector/build.sh b/omnidata/omnidata-hive-connector/build.sh deleted file mode 100644 index 98c426e22cc430cc1268816c9355bc13d98b8c9f..0000000000000000000000000000000000000000 --- 
a/omnidata/omnidata-hive-connector/build.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -mvn clean package -jar_name=`ls -n connector/target/*.jar | grep hive-exec | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}'` -dir_name=`ls -n connector/target/*.jar | grep hive-exec | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}' | awk -F '.jar' '{print$1}'` -rm -r $dir_name -rm -r $dir_name.zip -mkdir -p $dir_name -cp connector/target/$jar_name $dir_name -cd $dir_name -wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcpkix-jdk15on/1.68/bcpkix-jdk15on-1.68.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcprov-jdk15on/1.68/bcprov-jdk15on-1.68.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/google/guava/guava/31.1-jre/guava-31.1-jre.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/hetu-transport/1.6.1/hetu-transport-1.6.1.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-annotations/2.12.4/jackson-annotations-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-core/2.12.4/jackson-core-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/core/jackson-databind/2.12.4/jackson-databind-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-guava/2.12.4/jackson-datatype-guava-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jdk8/2.12.4/jackson-datatype-jdk8-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-joda/2.12.4/jackson-datatype-joda-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jsr310/2.12.4/jackson-datatype-jsr310-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/module/jackson-module-parameter-names/2.12.4/jackson-module-parameter-names-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/jasypt/jasypt/1.9.3/jasypt-1.9.3.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/openjdk/jol/jol-core/0.2/jol-core-0.2.jar -wget https://repo1.maven.org/maven2/io/airlift/joni/2.1.5.3/joni-2.1.5.3.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/esotericsoftware/kryo-shaded/4.0.2/kryo-shaded-4.0.2.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/log/0.193/log-0.193.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/perfmark/perfmark-api/0.23.0/perfmark-api-0.23.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-main/1.6.1/presto-main-1.6.1.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-spi/1.6.1/presto-spi-1.6.1.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/google/protobuf/protobuf-java/3.12.0/protobuf-java-3.12.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/slice/0.38/slice-0.38.jar -cd .. 
-zip -r -o $dir_name.zip $dir_name -rm -r $dir_name \ No newline at end of file diff --git a/omnidata/omnidata-hive-connector/hive_build.sh b/omnidata/omnidata-hive-connector/hive_build.sh new file mode 100644 index 0000000000000000000000000000000000000000..d32e27957388b29725cadb3ba05fd4cf0a8f56cb --- /dev/null +++ b/omnidata/omnidata-hive-connector/hive_build.sh @@ -0,0 +1,32 @@ +#!/bin/bash +mvn clean package +jar_name=`ls -n connector/target/*.jar | grep hive-exec | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}'` +dir_name=`ls -n connector/target/*.jar | grep hive-exec | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}' | awk -F '.jar' '{print$1}'` +if [ -d "${dir_name}" ];then rm -rf ${dir_name}; fi +if [ -d "${dir_name}.zip" ];then rm -rf ${dir_name}.zip; fi +mkdir -p $dir_name +cp connector/target/$jar_name $dir_name +cd $dir_name +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/bcpkix-jdk15on/1.68/package/bcpkix-jdk15on-1.68.jar +wget --proxy=off --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/guava/31.1-jre/package/guava-31.1-jre.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/hetu-transport/1.6.1/package/hetu-transport-1.6.1.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jackson-annotations/2.12.4/package/jackson-annotations-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jackson-core/2.12.4/package/jackson-core-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jackson-databind/2.12.4/package/jackson-databind-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jackson-datatype-guava/2.12.4/package/jackson-datatype-guava-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/jackson-datatype-jdk8/2.12.4/package/jackson-datatype-jdk8-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/Jackson-datatype-Joda/2.12.4/package/jackson-datatype-joda-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/jackson-datatype-jsr310/2.12.4/package/jackson-datatype-jsr310-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/jackson-module-parameter-names/2.12.4/package/jackson-module-parameter-names-2.12.4.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jasypt/1.9.3/package/jasypt-1.9.3.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/jol-core/0.2/package/jol-core-0.2.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/joni/2.1.5.3/package/joni-2.1.5.3.jar +wget --proxy=off --no-check-certificate 
https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/kryo-shaded/4.0.2/package/kryo-shaded-4.0.2.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/log/0.193/package/log-0.193.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/perfmark-api/0.23.0/package/perfmark-api-0.23.0.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/presto-main/1.6.1/package/presto-main-1.6.1.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/presto-spi/1.6.1/package/presto-spi-1.6.1.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/protobuf-java/3.12.0/package/protobuf-java-3.12.0.jar +wget --proxy=off --no-check-certificate https://cmc.cloudartifact.szv.dragon.tools.huawei.com/artifactory/opensource_general/slice/0.38/package/slice-0.38.jar +cd .. +zip -r -o $dir_name.zip $dir_name \ No newline at end of file diff --git a/omnidata/omnidata-server-lib/pom.xml b/omnidata/omnidata-server-lib/pom.xml index a782bd38e282d3132443ba1c01f14822f9aa255a..13ebeb82298dbca42b7598a5a6dfdf65b1f90144 100644 --- a/omnidata/omnidata-server-lib/pom.xml +++ b/omnidata/omnidata-server-lib/pom.xml @@ -7,18 +7,24 @@ com.huawei.boostkit omnidata-server-lib pom - 1.4.0 + 1.5.0 ${os.detected.arch} 2.11.4 1.2.3 - 1.6.1 + 1.9.0 206 2.12.0 + 0.9 + + net.openhft + zero-allocation-hashing + ${dep.net.openhft.version} + com.fasterxml.jackson.core jackson-databind @@ -96,11 +102,6 @@ - - com.google.guava - guava - 30.0-jre - io.prestosql.hadoop hadoop-apache @@ -231,6 +232,21 @@ + + io.hetu.core + presto-expressions + ${dep.hetu.version} + + + io.hetu.core + hetu-common + ${dep.hetu.version} + + + com.google.guava + guava + 30.0-jre + io.airlift units diff --git a/omnidata/omnidata-spark-connector/README.md b/omnidata/omnidata-spark-connector/README.md index c773c416e83311847a6aea19163aaf83aa5d9492..de2c8b8c72f7ae75d7346a6306e4dbe70c901bee 100644 --- a/omnidata/omnidata-spark-connector/README.md +++ b/omnidata/omnidata-spark-connector/README.md @@ -5,7 +5,7 @@ Introduction ============ -The omnidata spark connector library running on Kunpeng processors is a Spark SQL plugin that pushes computing-side operators to storage nodes for computing. It is developed based on original APIs of Apache [Spark 3.0.0](https://github.com/apache/spark/tree/v3.0.0). This library applies to the big data storage separation scenario or large-scale fusion scenario where a large number of compute nodes read data from remote nodes. In this scenario, a large amount of raw data is transferred from storage nodes to compute nodes over the network for processing, resulting in a low proportion of valid data and a huge waste of network bandwidth. You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions. +The omnidata spark connector library running on Kunpeng processors is a Spark SQL plugin that pushes computing-side operators to storage nodes for computing. It is developed based on original APIs of Apache [Spark 3.1.1](https://github.com/apache/spark/tree/v3.1.1). 
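(Aside, not part of this patch: a minimal sketch of how a Spark SQL pushdown plugin of this kind is typically wired in through configuration. The extension class name and the exact property are illustrative assumptions; only the NDP_PLUGIN_ENABLE environment toggle is something the connector's NdpConnectorUtils below actually reads.)

```scala
// Illustrative sketch only, under stated assumptions; substitute the real
// extension class shipped in the connector zip for the placeholder below.
import org.apache.spark.SparkConf

object PushdownConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("omnidata-pushdown-example")
      // Hypothetical class name; "spark.sql.extensions" is the standard hook
      // for SparkSessionExtensions-based plugins, assumed (not confirmed) here.
      .set("spark.sql.extensions", "com.example.NdpPluginExtension")

    // NdpConnectorUtils.getNdpEnable() checks the NDP_PLUGIN_ENABLE environment
    // variable at runtime, so the submitting shell would also export
    // NDP_PLUGIN_ENABLE=true before launching the application.
    conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
  }
}
```
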
This library applies to the big data storage separation scenario or large-scale fusion scenario where a large number of compute nodes read data from remote nodes. In this scenario, a large amount of raw data is transferred from storage nodes to compute nodes over the network for processing, resulting in a low proportion of valid data and a huge waste of network bandwidth. You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions. Building And Packageing diff --git a/omnidata/omnidata-spark-connector/build.sh b/omnidata/omnidata-spark-connector/build.sh deleted file mode 100644 index 7a528d4476b1a6d758054576f289d09ce8ae9c3e..0000000000000000000000000000000000000000 --- a/omnidata/omnidata-spark-connector/build.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -mvn clean package -jar_name=`ls -n connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}'` -dir_name=`ls -n connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}' | awk -F '.jar' '{print$1}'` -rm -r $dir_name-aarch64 -rm -r $dir_name-aarch64.zip -mkdir -p $dir_name-aarch64 -cp connector/target/$jar_name $dir_name-aarch64 -cd $dir_name-aarch64 -wget https://mirrors.huaweicloud.com/repository/maven/org/bouncycastle/bcpkix-jdk15on/1.68/bcpkix-jdk15on-1.68.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-client/2.12.0/curator-client-2.12.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-framework/2.12.0/curator-framework-2.12.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/apache/curator/curator-recipes/2.12.0/curator-recipes-2.12.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/alibaba/fastjson/1.2.76/fastjson-1.2.76.jar -wget https://mirrors.huaweicloud.com/repository/maven/de/ruedigermoeller/fst/2.57/fst-2.57.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/google/guava/guava/26.0-jre/guava-26.0-jre.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/hetu-transport/1.6.1/hetu-transport-1.6.1.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-guava/2.12.4/jackson-datatype-guava-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jdk8/2.12.4/jackson-datatype-jdk8-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-joda/2.12.4/jackson-datatype-joda-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/datatype/jackson-datatype-jsr310/2.12.4/jackson-datatype-jsr310-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/fasterxml/jackson/module/jackson-module-parameter-names/2.12.4/jackson-module-parameter-names-2.12.4.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/jasypt/jasypt/1.9.3/jasypt-1.9.3.jar -wget https://mirrors.huaweicloud.com/repository/maven/org/openjdk/jol/jol-core/0.2/jol-core-0.2.jar -wget https://repo1.maven.org/maven2/io/airlift/joni/2.1.5.3/joni-2.1.5.3.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/log/0.193/log-0.193.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/perfmark/perfmark-api/0.23.0/perfmark-api-0.23.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-main/1.6.1/presto-main-1.6.1.jar -wget 
https://mirrors.huaweicloud.com/repository/maven/io/hetu/core/presto-spi/1.6.1/presto-spi-1.6.1.jar -wget https://mirrors.huaweicloud.com/repository/maven/com/google/protobuf/protobuf-java/3.12.0/protobuf-java-3.12.0.jar -wget https://mirrors.huaweicloud.com/repository/maven/io/airlift/slice/0.38/slice-0.38.jar -cd .. -zip -r -o $dir_name-aarch64.zip $dir_name-aarch64 -rm -r $dir_name-aarch64 \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/pom.xml b/omnidata/omnidata-spark-connector/connector/pom.xml index 4fd2668b10b13f63ccef418c56af7ee9d1bfe8ce..6062530d9b26a1182ed29572f06b5ff811c7eaba 100644 --- a/omnidata/omnidata-spark-connector/connector/pom.xml +++ b/omnidata/omnidata-spark-connector/connector/pom.xml @@ -5,12 +5,12 @@ org.apache.spark omnidata-spark-connector-root - 1.4.0 + 1.5.0 4.0.0 boostkit-omnidata-spark-sql_2.12-3.1.1 - 1.4.0 + 1.5.0 boostkit omnidata spark sql 2021 jar @@ -24,8 +24,14 @@ 2.12.0 1.6.1 1.35.0 + 2.12 + + io.hetu.core + presto-expressions + ${dep.hetu.version} + org.apache.spark spark-hive_2.12 @@ -55,7 +61,7 @@ com.huawei.boostkit boostkit-omnidata-stub - 1.4.0 + 1.5.0 compile @@ -73,9 +79,111 @@ curator-recipes ${dep.curator.version} + + + + io.airlift + log + 206 + test + + + io.airlift + stats + 206 + + + org.apache.lucene + lucene-analyzers-common + 7.2.1 + + + it.unimi.dsi + fastutil + 6.5.9 + + + io.airlift + bytecode + 1.2 + + + io.hetu.core + presto-parser + ${dep.hetu.version} + test + + + io.airlift + json + 206 + + + org.testng + testng + 6.10 + test + + + org.mockito + mockito-core + 1.9.5 + test + + + objenesis + org.objenesis + + + + + org.scalatest + scalatest_${scala.binary.version} + 3.2.3 + test + + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + test + ${spark.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${spark.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + test + ${spark.version} + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + test + ${spark.version} + + + org.apache.spark + spark-hive_${scala.binary.version} + test + ${spark.version} + + + org.apache.hive + hive-cli + 2.3.7 + src/main/scala + src/test/java org.codehaus.mojo diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..9d0937625f48aa4ca15a601f274e0e05760944fc --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/NdpConnectorUtils.java @@ -0,0 +1,188 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omnidata.spark; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Set; +import java.util.regex.Pattern; + +public class NdpConnectorUtils { + + private static final Logger LOG = LoggerFactory.getLogger(NdpConnectorUtils.class); + public static Set getIpAddress() { + Set ipSet = new HashSet<>(); + try { + Enumeration allNetInterfaces = NetworkInterface.getNetworkInterfaces(); + while (allNetInterfaces.hasMoreElements()) { + NetworkInterface netInterface = (NetworkInterface) allNetInterfaces.nextElement(); + if (netInterface.isLoopback() || netInterface.isVirtual() || !netInterface.isUp()) { + continue; + } + Enumeration addresses = netInterface.getInetAddresses(); + while (addresses.hasMoreElements()) { + InetAddress ip = addresses.nextElement(); + if (ip instanceof Inet4Address) { + ipSet.add(ip.getHostAddress()); + } + } + } + } catch (Exception e) { + LOG.error("getIpAddress exception:", e); + } + return ipSet; + } + + public static boolean isNumeric(String str) { + return str != null && str.chars().allMatch(Character::isDigit); + } + + private static int getIntSysEnv(String envName, int defaultVal) { + String val = System.getenv(envName); + if (isNumeric(val) && Integer.parseInt(val) > 0) { + return Integer.parseInt(val); + } + return defaultVal; + } + + private static String getIntStrSysEnv(String envName, String defaultVal) { + String val = System.getenv(envName); + if (isNumeric(val) && Integer.parseInt(val) > 0) { + return val; + } + return defaultVal; + } + + public static boolean getNdpEnable() { + String isEnable = System.getenv("NDP_PLUGIN_ENABLE"); + return isEnable != null && isEnable.equals("true"); + } + + public static int getPushDownTaskTotal(int taskTotal) { + return getIntSysEnv("DEFAULT_PUSHDOWN_TASK", taskTotal); + } + + public static String getNdpNumPartitionsStr(String numStr) { + return getIntStrSysEnv("DEFAULT_NDP_NUM_PARTITIONS", numStr); + } + + public static int getCountTaskTotal(int taskTotal) { + return getIntSysEnv("COUNT_TASK_TOTAL", taskTotal); + } + + public static String getCountMaxPartSize(String size) { + return System.getenv("COUNT_MAX_PART_SIZE") != null ? + System.getenv("COUNT_MAX_PART_SIZE") : size; + } + + public static int getCountDistinctTaskTotal(int taskTotal) { + return getIntSysEnv("COUNT_DISTINCT_TASK_TOTAL", taskTotal); + } + + public static String getSMJMaxPartSize(String size) { + return System.getenv("SMJ_MAX_PART_SIZE") != null ? 
+ System.getenv("SMJ_MAX_PART_SIZE") : size; + } + + public static int getSMJNumPartitions(int numPartitions) { + return getIntSysEnv("SMJ_NUM_PARTITIONS", numPartitions); + } + + public static int getOmniColumnarNumPartitions(int numPartitions) { + return getIntSysEnv("OMNI_COLUMNAR_PARTITIONS", numPartitions); + } + + public static int getOmniColumnarTaskCount(int taskTotal) { + return getIntSysEnv("OMNI_COLUMNAR_TASK_TOTAL", taskTotal); + } + + public static int getFilterPartitions(int numPartitions) { + return getIntSysEnv("FILTER_COLUMNAR_PARTITIONS", numPartitions); + } + + public static int getFilterTaskCount(int taskTotal) { + return getIntSysEnv("FILTER_TASK_TOTAL", taskTotal); + } + + public static String getSortRepartitionSizeStr(String sizeStr) { + return System.getenv("SORT_REPARTITION_SIZE") != null ? + System.getenv("SORT_REPARTITION_SIZE") : sizeStr; + } + + public static String getCastDecimalPrecisionStr(String numStr) { + return System.getenv("CAST_DECIMAL_PRECISION") != null ? + System.getenv("CAST_DECIMAL_PRECISION") : numStr; + } + + public static String getNdpMaxPtFactorStr(String numStr) { + return System.getenv("NDP_MAX_PART_FACTOR") != null ? + System.getenv("NDP_MAX_PART_FACTOR") : numStr; + } + + public static String getCountAggMaxFilePtBytesStr(String BytesStr) { + return System.getenv("COUNT_AGG_MAX_FILE_BYTES") != null ? + System.getenv("COUNT_AGG_MAX_FILE_BYTES") : BytesStr; + } + + public static String getAvgAggMaxFilePtBytesStr(String BytesStr) { + return System.getenv("AVG_AGG_MAX_FILE_BYTES") != null ? + System.getenv("AVG_AGG_MAX_FILE_BYTES") : BytesStr; + } + + public static String getBhjMaxFilePtBytesStr(String BytesStr) { + return System.getenv("BHJ_MAX_FILE_BYTES") != null ? + System.getenv("BHJ_MAX_FILE_BYTES") : BytesStr; + } + + public static String getGroupMaxFilePtBytesStr(String BytesStr) { + return System.getenv("GROUP_MAX_FILE_BYTES") != null ? + System.getenv("GROUP_MAX_FILE_BYTES") : BytesStr; + } + + public static String getMixSqlBaseMaxFilePtBytesStr(String BytesStr) { + return System.getenv("MIX_SQL_BASE_MAX_FILE_BYTES") != null ? + System.getenv("MIX_SQL_BASE_MAX_FILE_BYTES") : BytesStr; + } + + public static String getMixSqlAccurateMaxFilePtBytesStr(String BytesStr) { + return System.getenv("MIX_SQL_ACCURATE_MAX_FILE_BYTES") != null ? + System.getenv("MIX_SQL_ACCURATE_MAX_FILE_BYTES") : BytesStr; + } + + public static String getAggShufflePartitionsStr(String BytesStr) { + return System.getenv("AGG_SHUFFLE_PARTITIONS") != null ? + System.getenv("AGG_SHUFFLE_PARTITIONS") : BytesStr; + } + + public static String getShufflePartitionsStr(String BytesStr) { + return System.getenv("SHUFFLE_PARTITIONS") != null ? + System.getenv("SHUFFLE_PARTITIONS") : BytesStr; + } + + public static String getSortShufflePartitionsStr(String BytesStr) { + return System.getenv("SORT_SHUFFLE_PARTITIONS") != null ? + System.getenv("SORT_SHUFFLE_PARTITIONS") : BytesStr; + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDeRunLength.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDeRunLength.java new file mode 100644 index 0000000000000000000000000000000000000000..8dbde71f4428d026c0d0cea583c7d7a16efe411a --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDeRunLength.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. 
All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.omnidata.spark; + +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; + +/** + * DeCompress RunLength for combine with operator + * + * @since 2023-07-20 + */ +public class OperatorPageDeRunLength extends PageDeRunLength { + @Override + protected WritableColumnVector getColumnVector(int positionCount, WritableColumnVector writableColumnVector) { + return new OmniColumnVector(positionCount, writableColumnVector.dataType(), true); + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDecoding.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDecoding.java new file mode 100644 index 0000000000000000000000000000000000000000..afaa26066f35457cf798b6ee7e1291c0842cecb7 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/OperatorPageDecoding.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omnidata.spark; + +import com.huawei.boostkit.omnidata.decode.type.*; +import io.airlift.slice.SliceInput; +import org.apache.spark.sql.execution.util.SparkMemoryUtils; +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; + +import java.lang.reflect.InvocationTargetException; +import java.util.Optional; + +/** + * Decode data to spark writableColumnVector for combine with operator + * + * @since 2023-07-20 + */ +public class OperatorPageDecoding extends PageDecoding { + + static { + SparkMemoryUtils.init(); + } + + public OperatorPageDecoding(String fileFormat) { + super(fileFormat); + } + + @Override + public Optional decodeVariableWidth(Optional type, SliceInput sliceInput) { + int positionCount = sliceInput.readInt(); + return decodeVariableWidthBase(type, sliceInput, + new OmniColumnVector(positionCount, DataTypes.StringType, true), positionCount); + } + + @Override + public Optional decodeRunLength(Optional type, SliceInput sliceInput) + throws InvocationTargetException, IllegalAccessException { + return decodeRunLengthBase(type, sliceInput, new OperatorPageDeRunLength()); + } + + @Override + protected WritableColumnVector createColumnVectorForDecimal(int positionCount, DecimalType decimalType) { + return new OmniColumnVector(positionCount, decimalType, true); + } + + @Override + protected Optional decodeSimple( + SliceInput sliceInput, + DataType dataType, + String dataTypeName) { + int positionCount = sliceInput.readInt(); + WritableColumnVector columnVector = new OmniColumnVector(positionCount, dataType, true); + return getWritableColumnVector(sliceInput, positionCount, columnVector, dataTypeName); + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeRunLength.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeRunLength.java index 7802b7b5ab0fbd0880b123f6ae69ac6cf1acf546..0a3ca5db223d633996eeda21244798c3112c01c5 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeRunLength.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeRunLength.java @@ -20,7 +20,6 @@ package com.huawei.boostkit.omnidata.spark; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.DecimalType; @@ -54,7 +53,7 @@ public class PageDeRunLength { public Optional decompressByteArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { byte value = writableColumnVector.getByte(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ByteType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -78,7 +77,7 @@ public class PageDeRunLength { public Optional decompressBooleanArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { boolean value = 
writableColumnVector.getBoolean(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.BooleanType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -102,7 +101,7 @@ public class PageDeRunLength { public Optional decompressIntArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { int value = writableColumnVector.getInt(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.IntegerType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -126,7 +125,7 @@ public class PageDeRunLength { public Optional decompressShortArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { short value = writableColumnVector.getShort(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ShortType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -150,7 +149,7 @@ public class PageDeRunLength { public Optional decompressLongArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { long value = writableColumnVector.getLong(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.LongType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -174,7 +173,7 @@ public class PageDeRunLength { public Optional decompressFloatArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { float value = writableColumnVector.getFloat(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.FloatType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -198,7 +197,7 @@ public class PageDeRunLength { public Optional decompressDoubleArray(int positionCount, WritableColumnVector writableColumnVector) throws Exception { double value = writableColumnVector.getDouble(0); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.DoubleType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -221,7 +220,7 @@ public class PageDeRunLength { */ public Optional decompressVariableWidth(int positionCount, WritableColumnVector writableColumnVector) throws Exception { - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.StringType); + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); if (writableColumnVector.isNullAt(0)) { columnVector.putNulls(0, positionCount); } else { @@ -247,11 +246,11 @@ public class PageDeRunLength { int precision = ((DecimalType) writableColumnVector.dataType()).precision(); int scale = ((DecimalType) writableColumnVector.dataType()).scale(); Decimal value = writableColumnVector.getDecimal(0, precision, scale); - WritableColumnVector 
columnVector = new OnHeapColumnVector(positionCount, writableColumnVector.dataType()); - for (int rowId = 0; rowId < positionCount; rowId++) { - if (writableColumnVector.isNullAt(rowId)) { - columnVector.putNull(rowId); - } else { + WritableColumnVector columnVector = getColumnVector(positionCount, writableColumnVector); + if (writableColumnVector.isNullAt(0)) { + columnVector.putNulls(0, positionCount); + } else { + for (int rowId = 0; rowId < positionCount; rowId++) { columnVector.putDecimal(rowId, value, precision); } } @@ -262,4 +261,8 @@ public class PageDeRunLength { } return Optional.of(columnVector); } -} + + protected WritableColumnVector getColumnVector(int positionCount, WritableColumnVector writableColumnVector) { + return PageDecoding.createColumnVector(positionCount, writableColumnVector.dataType()); + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDecoding.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDecoding.java index 5d46338bdb5d2791491d53623dddc6725d588c55..e140dfd7c4125a9b7a18a125b49093e2ecab51a1 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDecoding.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDecoding.java @@ -21,7 +21,6 @@ package com.huawei.boostkit.omnidata.spark; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static java.lang.Double.longBitsToDouble; import static java.lang.Float.intBitsToFloat; -import static org.apache.spark.sql.types.DataTypes.TimestampType; import com.huawei.boostkit.omnidata.decode.AbstractDecoding; import com.huawei.boostkit.omnidata.decode.type.*; @@ -30,13 +29,19 @@ import com.huawei.boostkit.omnidata.exception.OmniDataException; import io.airlift.slice.SliceInput; import io.airlift.slice.Slices; import io.prestosql.spi.type.DateType; - import io.prestosql.spi.type.Decimals; + +import org.apache.spark.sql.catalyst.util.RebaseDateTime; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.DecimalType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; @@ -53,7 +58,12 @@ import java.util.TimeZone; * @since 2021-03-30 */ public class PageDecoding extends AbstractDecoding> { - private static Field filedElementsAppended; + private static final Logger LOG = LoggerFactory.getLogger(PageDecoding.class); + + /** + * Log appended files. 
+ */ + static Field filedElementsAppended; static { try { @@ -64,73 +74,30 @@ public class PageDecoding extends AbstractDecoding decodeArray(Optional type, SliceInput sliceInput) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("not support array decode"); } @Override public Optional decodeByteArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ByteType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putByte(position, sliceInput.readByte()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.ByteType, "byte"); } @Override public Optional decodeBooleanArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.BooleanType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - boolean value = sliceInput.readByte() != 0; - columnVector.putBoolean(position, value); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.BooleanType, "boolean"); } @Override public Optional decodeIntArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.IntegerType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putInt(position, sliceInput.readInt()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.IntegerType, "int"); } @Override @@ -140,86 +107,22 @@ public class PageDecoding extends AbstractDecoding decodeShortArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ShortType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putShort(position, sliceInput.readShort()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, 
DataTypes.ShortType, "short"); } @Override public Optional decodeLongArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.LongType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putLong(position, sliceInput.readLong()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.LongType, "long"); } @Override public Optional decodeFloatArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.FloatType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putFloat(position, intBitsToFloat(sliceInput.readInt())); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.FloatType, "float"); } @Override public Optional decodeDoubleArray(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.DoubleType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putDouble(position, longBitsToDouble(sliceInput.readLong())); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.DoubleType, "double"); } @Override @@ -232,17 +135,17 @@ public class PageDecoding extends AbstractDecoding decodeVariableWidth(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - + protected Optional decodeVariableWidthBase( + Optional type, + SliceInput sliceInput, + WritableColumnVector columnVector, + int positionCount) { int[] offsets = new int[positionCount + 1]; sliceInput.readBytes(Slices.wrappedIntArray(offsets), SIZE_OF_INT, Math.multiplyExact(positionCount, SIZE_OF_INT)); boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); int blockSize = sliceInput.readInt(); int curOffset = offsets[0]; int nextOffset; - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.StringType); for (int position = 0; position < positionCount; position++) { if (valueIsNull == null || !valueIsNull[position]) { nextOffset = offsets[position + 1]; @@ -250,7 +153,11 @@ public class PageDecoding extends AbstractDecoding decodeVariableWidth(Optional type, SliceInput sliceInput) { + int positionCount = sliceInput.readInt(); + 
return decodeVariableWidthBase(type, sliceInput, + createColumnVector(positionCount, DataTypes.StringType), positionCount); + } + @Override public Optional decodeDictionary(Optional type, SliceInput sliceInput) { throw new UnsupportedOperationException(); } - @Override - public Optional decodeRunLength(Optional type, SliceInput sliceInput) + protected Optional decodeRunLengthBase( + Optional type, + SliceInput sliceInput, + PageDeRunLength pageDeRunLength) throws InvocationTargetException, IllegalAccessException { int positionCount = sliceInput.readInt(); Optional resColumnVector = Optional.empty(); @@ -292,7 +208,6 @@ public class PageDecoding extends AbstractDecoding decodeRunLength(Optional type, SliceInput sliceInput) + throws InvocationTargetException, IllegalAccessException { + return decodeRunLengthBase(type, sliceInput, new PageDeRunLength()); + } + @Override public Optional decodeRow(Optional type, SliceInput sliceInput) { return Optional.empty(); @@ -311,123 +232,50 @@ public class PageDecoding extends AbstractDecoding decodeDate(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.DateType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putInt(position, sliceInput.readInt()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.DateType, "date"); } @Override public Optional decodeLongToInt(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.IntegerType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putInt(position, (int) sliceInput.readLong()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.IntegerType, "longToInt"); } @Override public Optional decodeLongToShort(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ShortType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putShort(position, (short) sliceInput.readLong()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.ShortType, "longToShort"); } @Override public Optional decodeLongToByte(Optional type, SliceInput sliceInput) { - int positionCount = 
sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.ByteType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putByte(position, (byte) sliceInput.readLong()); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.ByteType, "longToByte"); } @Override public Optional decodeLongToFloat(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); + return decodeSimple(sliceInput, DataTypes.FloatType, "longToFloat"); + } - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, DataTypes.FloatType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - columnVector.putFloat(position, intBitsToFloat((int) sliceInput.readLong())); - } else { - columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + + protected WritableColumnVector createColumnVectorForDecimal(int positionCount, DecimalType decimalType) { + return createColumnVector(positionCount, decimalType); } @Override public Optional decodeDecimal(Optional type, SliceInput sliceInput, String decodeName) { int positionCount = sliceInput.readInt(); - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - if (!(type.get() instanceof DecimalDecodeType)) { - Optional.empty(); + DecimalDecodeType decimalDecodeType; + if ((type.get() instanceof DecimalDecodeType)) { + decimalDecodeType = (DecimalDecodeType) type.get(); + } else { + return Optional.empty(); } - DecimalDecodeType decimalDecodeType = (DecimalDecodeType) type.get(); int scale = decimalDecodeType.getScale(); int precision = decimalDecodeType.getPrecision(); - OnHeapColumnVector columnVector = new OnHeapColumnVector(positionCount, new DecimalType(precision, scale)); + WritableColumnVector columnVector = createColumnVectorForDecimal(positionCount, new DecimalType(precision, scale)); + boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); for (int position = 0; position < positionCount; position++) { if (valueIsNull == null || !valueIsNull[position]) { - BigInteger value = null; + BigInteger value; switch (decodeName) { case "LONG_ARRAY": value = BigInteger.valueOf(sliceInput.readLong()); @@ -454,28 +302,10 @@ public class PageDecoding extends AbstractDecoding decodeTimestamp(Optional type, SliceInput sliceInput) { - int positionCount = sliceInput.readInt(); - - boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); - WritableColumnVector columnVector = new OnHeapColumnVector(positionCount, TimestampType); - for (int position = 0; position < positionCount; position++) { - if (valueIsNull == null || !valueIsNull[position]) { - // milliseconds to microsecond - int rawOffset = TimeZone.getDefault().getRawOffset(); - columnVector.putLong(position, (sliceInput.readLong() - rawOffset) * 1000); - } else { 
- columnVector.putNull(position); - } - } - try { - PageDecoding.filedElementsAppended.set(columnVector, positionCount); - } catch (Exception e) { - throw new OmniDataException(e.getMessage()); - } - return Optional.of(columnVector); + return decodeSimple(sliceInput, DataTypes.TimestampType, "timestamp"); } - private Optional typeToDecodeName(Optional optType) { + Optional typeToDecodeName(Optional optType) { Class javaType = null; if (!optType.isPresent()) { return Optional.empty(); @@ -513,4 +343,91 @@ public class PageDecoding extends AbstractDecoding getWritableColumnVector(SliceInput sliceInput, int positionCount, + WritableColumnVector columnVector, String type) { + boolean[] valueIsNull = decodeNullBits(sliceInput, positionCount).orElse(null); + for (int position = 0; position < positionCount; position++) { + if (valueIsNull == null || !valueIsNull[position]) { + putData(columnVector, sliceInput, position, type); + } else { + columnVector.putNull(position); + } + } + try { + PageDecoding.filedElementsAppended.set(columnVector, positionCount); + } catch (Exception e) { + throw new OmniDataException(e.getMessage()); + } + return Optional.of(columnVector); + } + + private void putData(WritableColumnVector columnVector, SliceInput sliceInput, int position, String type) { + switch (type) { + case "byte": + columnVector.putByte(position, sliceInput.readByte()); + break; + case "boolean": + columnVector.putBoolean(position, sliceInput.readByte() != 0); + break; + case "int": + columnVector.putInt(position, sliceInput.readInt()); + break; + case "date": + int src = sliceInput.readInt(); + if ("ORC".equalsIgnoreCase(fileFormat)) { + src = RebaseDateTime.rebaseJulianToGregorianDays(src); + } + columnVector.putInt(position, src); + break; + case "short": + columnVector.putShort(position, sliceInput.readShort()); + break; + case "long": + columnVector.putLong(position, sliceInput.readLong()); + break; + case "float": + columnVector.putFloat(position, intBitsToFloat(sliceInput.readInt())); + break; + case "double": + columnVector.putDouble(position, longBitsToDouble(sliceInput.readLong())); + break; + case "longToInt": + columnVector.putInt(position, (int) sliceInput.readLong()); + break; + case "longToShort": + columnVector.putShort(position, (short) sliceInput.readLong()); + break; + case "longToByte": + columnVector.putByte(position, (byte) sliceInput.readLong()); + break; + case "longToFloat": + columnVector.putFloat(position, intBitsToFloat((int) sliceInput.readLong())); + break; + case "timestamp": + // milliseconds to microsecond + int rawOffset = TimeZone.getDefault().getRawOffset(); + columnVector.putLong(position, (sliceInput.readLong() - rawOffset) * 1000); + break; + default: + } + } + + protected Optional decodeSimple( + SliceInput sliceInput, + DataType dataType, + String dataTypeName) { + int positionCount = sliceInput.readInt(); + WritableColumnVector columnVector = createColumnVector(positionCount, dataType); + return getWritableColumnVector(sliceInput, positionCount, columnVector, dataTypeName); + } + + protected static WritableColumnVector createColumnVector(int positionCount, DataType dataType) { + boolean offHeapEnable = SQLConf.get().offHeapColumnVectorEnabled(); + if (offHeapEnable) { + return new OffHeapColumnVector(positionCount, dataType); + } else { + return new OnHeapColumnVector(positionCount, dataType); + } + } +} \ No newline at end of file diff --git 
a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeserializer.java b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeserializer.java index 062afec51d9ac89a0ec95e27498a18c30e0823fc..f4e8f5bd34c750f9cab4cc56e721c7cba972f742 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeserializer.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/com/huawei/boostkit/omnidata/spark/PageDeserializer.java @@ -30,6 +30,8 @@ import io.airlift.slice.SliceInput; import io.hetu.core.transport.execution.buffer.SerializedPage; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Optional; @@ -39,15 +41,32 @@ import java.util.Optional; * @since 2021-03-30 */ public class PageDeserializer implements Deserializer { + private static final Logger LOG = LoggerFactory.getLogger(PageDeserializer.class); + private final PageDecoding decoding; private final DecodeType[] columnTypes; private final int[] columnOrders; - public PageDeserializer(DecodeType[] columnTypes, int[] columnOrders) { + /** + * initialize page deserializer + * + * @param columnTypes column type + * @param columnOrders column index + * @param isOperatorCombineEnabled whether combine is enabled + */ + public PageDeserializer(DecodeType[] columnTypes, + int[] columnOrders, + boolean isOperatorCombineEnabled, + String fileFormat) { this.columnTypes = columnTypes; - this.decoding = new PageDecoding(); + if (isOperatorCombineEnabled) { + this.decoding = new OperatorPageDecoding(fileFormat); + LOG.debug("OmniRuntime PushDown deserialization info: deserialize to OmniColumnVector"); + } else { + this.decoding = new PageDecoding(fileFormat); + } this.columnOrders = columnOrders; } @@ -56,6 +75,7 @@ public class PageDeserializer implements Deserializer { if (page.isEncrypted()) { throw new UnsupportedOperationException("unsupported compressed page."); } + SliceInput sliceInput = page.getSlice().getInput(); int numberOfBlocks = sliceInput.readInt(); int returnLength = columnOrders.length; @@ -88,5 +108,4 @@ public class PageDeserializer implements Deserializer { } return columnVectors; } - -} +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/ColumnInfo.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/ColumnInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..fd135e77515c4043bf32b88e53f7e71481c5ffee --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/ColumnInfo.java @@ -0,0 +1,29 @@ +package org.apache.spark.sql; + +import io.prestosql.spi.type.Type; + +public class ColumnInfo { + private PrestoExpressionInfo expressionInfo; + + private Type prestoType; + + private int filterProjectionId; + + public ColumnInfo(PrestoExpressionInfo expressionInfo, Type prestoType, int filterProjectionId) { + this.expressionInfo = expressionInfo; + this.prestoType = prestoType; + this.filterProjectionId = filterProjectionId; + } + + public PrestoExpressionInfo getExpressionInfo() { + return expressionInfo; + } + + public Type getPrestoType() { + return prestoType; + } + + public int getFilterProjectionId() { + return filterProjectionId; + } +} diff --git 
a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java index 57dc84a1beba09dfde351c5272f253e1d15f89a7..5f8f48b9420a763edc5e18bc5ac5bdf78ec56fc0 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java @@ -26,8 +26,6 @@ import static io.prestosql.spi.type.BooleanType.BOOLEAN; import com.huawei.boostkit.omnidata.decode.type.DecodeType; import com.huawei.boostkit.omnidata.decode.type.LongDecodeType; import com.huawei.boostkit.omnidata.decode.type.RowDecodeType; -import com.huawei.boostkit.omnidata.exception.OmniDataException; -import com.huawei.boostkit.omnidata.exception.OmniErrorCode; import com.huawei.boostkit.omnidata.model.AggregationInfo; import com.huawei.boostkit.omnidata.model.Column; import com.huawei.boostkit.omnidata.model.Predicate; @@ -40,21 +38,25 @@ import com.huawei.boostkit.omnidata.spark.PageDeserializer; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.QualifiedObjectName; import io.prestosql.spi.function.BuiltInFunctionHandle; import io.prestosql.spi.function.FunctionHandle; import io.prestosql.spi.function.FunctionKind; import io.prestosql.spi.function.Signature; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.relation.CallExpression; -import io.prestosql.spi.relation.ConstantExpression; +import io.prestosql.spi.relation.DomainTranslator; import io.prestosql.spi.relation.InputReferenceExpression; import io.prestosql.spi.relation.RowExpression; import io.prestosql.spi.relation.SpecialForm; -import io.prestosql.spi.type.BigintType; -import io.prestosql.spi.type.DoubleType; -import io.prestosql.spi.type.RowType; -import io.prestosql.spi.type.Type; -import io.prestosql.spi.type.TypeSignature; +import io.prestosql.spi.type.*; +import io.prestosql.sql.relational.RowExpressionDomainTranslator; +import org.apache.spark.TaskContext; +import org.apache.spark.sql.catalyst.expressions.BinaryComparison; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.UnaryExpression; +import org.apache.spark.sql.execution.ndp.NdpConf; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -64,17 +66,9 @@ import org.apache.spark.sql.catalyst.expressions.And; import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic; -import org.apache.spark.sql.catalyst.expressions.Cast; import org.apache.spark.sql.catalyst.expressions.Divide; -import org.apache.spark.sql.catalyst.expressions.EqualTo; import org.apache.spark.sql.catalyst.expressions.Expression; -import org.apache.spark.sql.catalyst.expressions.GreaterThan; -import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; import org.apache.spark.sql.catalyst.expressions.In; -import org.apache.spark.sql.catalyst.expressions.IsNotNull; -import org.apache.spark.sql.catalyst.expressions.IsNull; -import org.apache.spark.sql.catalyst.expressions.LessThan; -import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.Multiply; import 
org.apache.spark.sql.catalyst.expressions.NamedExpression; @@ -83,7 +77,6 @@ import org.apache.spark.sql.catalyst.expressions.Or; import org.apache.spark.sql.catalyst.expressions.Remainder; import org.apache.spark.sql.catalyst.expressions.Subtract; import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction; -import org.apache.spark.sql.execution.datasources.PartitionedFile; import org.apache.spark.sql.execution.ndp.AggExeInfo; import org.apache.spark.sql.execution.ndp.FilterExeInfo; import org.apache.spark.sql.execution.ndp.PushDownInfo; @@ -95,7 +88,6 @@ import org.slf4j.LoggerFactory; import java.net.InetAddress; import java.net.UnknownHostException; -import java.sql.Date; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -116,6 +108,8 @@ import java.util.Set; public class DataIoAdapter { private int TASK_FAILED_TIMES = 4; + private int MAX_PAGE_SIZE_IN_BYTES = 1048576; + private List omnidataTypes = new ArrayList<>(); private List omnidataColumns = new ArrayList<>(); @@ -126,7 +120,7 @@ public class DataIoAdapter { private boolean hasNextPage = false; - private DataReaderImpl orcDataReader = null; + private DataReaderImpl orcDataReader = null; private List columnTypesList = new ArrayList<>(); @@ -156,6 +150,21 @@ public class DataIoAdapter { private static final Logger LOG = LoggerFactory.getLogger(DataIoAdapter.class); + private boolean isPushDownAgg = true; + + private boolean isOperatorCombineEnabled; + + private String omniGroupId; + + private final Thread shutdownHook = new Thread() { + public void run() { + if (orcDataReader != null && omniGroupId != null) { + LOG.info("force close task in {}", omniGroupId); + orcDataReader.forceClose(omniGroupId); + } + } + }; + /** * Contact with Omni-Data-Server * @@ -164,6 +173,7 @@ public class DataIoAdapter { * @param partitionColumn partition column * @param filterOutPut filter schema * @param pushDownOperators push down expressions + * @param domains domain map * @return WritableColumnVector data result info * @throws TaskExecutionException connect to omni-data-server failed exception * @notice 3rd parties api throws Exception, function has to catch basic Exception @@ -173,81 +183,115 @@ public class DataIoAdapter { Seq sparkOutPut, Seq partitionColumn, Seq filterOutPut, - PushDownInfo pushDownOperators) throws TaskExecutionException, UnknownHostException { + PushDownInfo pushDownOperators, + ImmutableMap domains, + Boolean isColumnVector, + String omniGroupId) throws TaskExecutionException, UnknownHostException { // initCandidates initCandidates(pageCandidate, filterOutPut); - // create AggregationInfo - // init agg candidates - List partitionColumnBatch = JavaConverters.seqAsJavaList(partitionColumn); - for (Attribute attribute : partitionColumnBatch) { - partitionColumnName.add(attribute.name()); - } - List aggExecutionList = - JavaConverters.seqAsJavaList(pushDownOperators.aggExecutions()); - if (aggExecutionList.size() == 0) { + // add partition column + JavaConverters.seqAsJavaList(partitionColumn).forEach(a -> partitionColumnName.add(a.name())); + + // init column info + if (pushDownOperators.aggExecutions().size() == 0) { + isPushDownAgg = false; + // deal with join has a project node with empty output + if (sparkOutPut.isEmpty()) { + sparkOutPut = filterOutPut; + } initColumnInfo(sparkOutPut); } - DataSource dataSource = initDataSource(pageCandidate); - RowExpression rowExpression = initFilter(pushDownOperators.filterExecutions()); - Optional prestoFilter = rowExpression == null ? 
- Optional.empty() : Optional.of(rowExpression); - Optional aggregations = - initAggAndGroupInfo(aggExecutionList); - // create limitLong + + // create filter + Optional filterRowExpression = initFilter(pushDownOperators.filterExecutions()); + + // create agg + Optional aggregations = initAggAndGroupInfo(pushDownOperators.aggExecutions()); + + // create limit OptionalLong limitLong = NdpUtils.convertLimitExeInfo(pushDownOperators.limitExecution()); + // create TaskSource + DataSource dataSource = initDataSource(pageCandidate); + + this.omniGroupId = omniGroupId; Predicate predicate = new Predicate( - omnidataTypes, omnidataColumns, prestoFilter, omnidataProjections, - ImmutableMap.of(), ImmutableMap.of(), aggregations, limitLong); - TaskSource taskSource = new TaskSource(dataSource, predicate, 1048576); - PageDeserializer deserializer = initPageDeserializer(); - WritableColumnVector[] page = null; + omnidataTypes, omnidataColumns, filterRowExpression, omnidataProjections, + domains, ImmutableMap.of(), aggregations, limitLong); + TaskSource taskSource = new TaskSource(dataSource, predicate, MAX_PAGE_SIZE_IN_BYTES, omniGroupId); + + // create deserializer + this.isOperatorCombineEnabled = + pageCandidate.isOperatorCombineEnabled() && NdpUtils.checkOmniOpColumns(omnidataColumns) + && isColumnVector; + PageDeserializer deserializer = initPageDeserializer(pageCandidate.getFileFormat()); + + // get available host + List pushDownHostList = new ArrayList<>(); + String[] pushDownHostArray; + if (pageCandidate.getpushDownHosts().length() == 0) { + Optional availablePushDownHost = getRandomAvailablePushDownHost(new String[]{}, + JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts())); + availablePushDownHost.ifPresent(pushDownHostList::add); + pushDownHostArray = pushDownHostList.toArray(new String[]{}); + } else { + pushDownHostArray = pageCandidate.getpushDownHosts().split(","); + pushDownHostList = new ArrayList<>(Arrays.asList(pushDownHostArray)); + Optional availablePushDownHost = getRandomAvailablePushDownHost(pushDownHostArray, + JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts())); + availablePushDownHost.ifPresent(pushDownHostList::add); + } + return getIterator(pushDownHostList.iterator(), taskSource, pushDownHostArray, deserializer, + pushDownHostList.size()); + } + + private Iterator getIterator(Iterator pushDownHosts, TaskSource taskSource, + String[] pushDownHostArray, PageDeserializer deserializer, + int pushDownHostsSize) throws UnknownHostException { + int randomIndex = (int) (Math.random() * pushDownHostArray.length); int failedTimes = 0; - String[] sdiHostArray = pageCandidate.getSdiHosts().split(","); - int randomIndex = (int) (Math.random() * sdiHostArray.length); - List sdiHostList = new ArrayList<>(Arrays.asList(sdiHostArray)); - Optional availableSdiHost = getRandomAvailableSdiHost(sdiHostArray, - JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts())); - availableSdiHost.ifPresent(sdiHostList::add); - Iterator sdiHosts = sdiHostList.iterator(); - Set sdiHostSet = new HashSet<>(); - sdiHostSet.add(sdiHostArray[randomIndex]); - while (sdiHosts.hasNext()) { - String sdiHost; + WritableColumnVector[] page = null; + Set pushDownHostSet = new HashSet<>(); + pushDownHostSet.add(pushDownHostArray[randomIndex]); + while (pushDownHosts.hasNext()) { + String pushDownHost; if (failedTimes == 0) { - sdiHost = sdiHostArray[randomIndex]; + pushDownHost = pushDownHostArray[randomIndex]; } else { - sdiHost = sdiHosts.next(); - if (sdiHostSet.contains(sdiHost)) { + pushDownHost = 
pushDownHosts.next(); + if (pushDownHostSet.contains(pushDownHost)) { continue; } } - String ipAddress = InetAddress.getByName(sdiHost).getHostAddress(); + String ipAddress = InetAddress.getByName(pushDownHost).getHostAddress(); Properties properties = new Properties(); properties.put("omnidata.client.target.list", ipAddress); properties.put("omnidata.client.task.timeout", taskTimeout); - LOG.info("Push down node info: [hostname :{} ,ip :{}]", sdiHost, ipAddress); + LOG.info("Push down node info: [hostname :{} ,ip :{}]", pushDownHost, ipAddress); try { - orcDataReader = new DataReaderImpl( + Runtime.getRuntime().addShutdownHook(shutdownHook); + orcDataReader = new DataReaderImpl<>( properties, taskSource, deserializer); hasNextPage = true; - page = (WritableColumnVector[]) orcDataReader.getNextPageBlocking(); + page = orcDataReader.getNextPageBlocking(); if (orcDataReader.isFinished()) { orcDataReader.close(); hasNextPage = false; + Runtime.getRuntime().removeShutdownHook(shutdownHook); } break; } catch (Exception e) { - LOG.warn("Push down failed node info [hostname :{} ,ip :{}]", sdiHost, ipAddress, e); + LOG.warn("Push down failed node info [hostname :{} ,ip :{}]", pushDownHost, ipAddress, e); ++failedTimes; if (orcDataReader != null) { orcDataReader.close(); hasNextPage = false; + Runtime.getRuntime().removeShutdownHook(shutdownHook); } } } - int retryTime = Math.min(TASK_FAILED_TIMES, sdiHostList.size()); + int retryTime = Math.min(TASK_FAILED_TIMES, pushDownHostsSize); if (failedTimes >= retryTime) { LOG.warn("No Omni-data-server to Connect, Task has tried {} times.", retryTime); throw new TaskExecutionException("No Omni-data-server to Connect"); @@ -257,62 +301,72 @@ public class DataIoAdapter { return l.iterator(); } - private Optional getRandomAvailableSdiHost(String[] sdiHostArray, Map fpuHosts) { - List existingHosts = Arrays.asList(sdiHostArray); - List allHosts = new ArrayList<>(fpuHosts.values()); + public void close() { + if (orcDataReader != null) { + orcDataReader.close(); + hasNextPage = false; + } + } + + private Optional getRandomAvailablePushDownHost(String[] pushDownHostArray, + Map fpuHosts) { + List existingHosts = Arrays.asList(pushDownHostArray); + List allHosts = new ArrayList<>(fpuHosts.keySet()); allHosts.removeAll(existingHosts); if (allHosts.size() > 0) { - LOG.info("Add another available host: " + allHosts.get(0)); - return Optional.of(allHosts.get(0)); + int randomIndex = (int) (Math.random() * allHosts.size()); + LOG.info("Add another available host: " + allHosts.get(randomIndex)); + return Optional.of(allHosts.get(randomIndex)); } else { return Optional.empty(); } } public boolean hasNextIterator(List pageList, PageToColumnar pageToColumnarClass, - PartitionedFile partitionFile, boolean isVectorizedReader) - throws Exception { + boolean isVectorizedReader, Seq sparkOutPut, String orcImpl) { if (!hasNextPage) { return false; } - WritableColumnVector[] page = (WritableColumnVector[]) orcDataReader.getNextPageBlocking(); - if (orcDataReader.isFinished()) { - orcDataReader.close(); - return false; + WritableColumnVector[] page = null; + try { + page = orcDataReader.getNextPageBlocking(); + if (orcDataReader.isFinished()) { + orcDataReader.close(); + hasNextPage = false; + Runtime.getRuntime().removeShutdownHook(shutdownHook); + return false; + } + } catch (Exception e) { + LOG.error("Push down failed", e); + if (orcDataReader != null) { + orcDataReader.close(); + hasNextPage = false; + Runtime.getRuntime().removeShutdownHook(shutdownHook); + } + throw e; } 
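
The hunk above reworks getIterator: a JVM shutdown hook is registered before each blocking read so an abruptly terminated task can still force-close the remote OmniData reader, the hook is removed once the reader is drained or the read fails, and the remaining push-down hosts are tried in turn, capped at min(TASK_FAILED_TIMES, host count). Below is a minimal sketch of that retry shape, with PageReader and connect() as placeholders rather than the connector's real DataReaderImpl API:

```java
// Sketch only: PageReader and connect() are stand-ins for the connector's reader plumbing.
import java.util.List;
import java.util.function.Function;

interface PageReader {
    Object[] getNextPageBlocking();   // blocking remote read
    boolean isFinished();
    void close();
    void forceClose();                // invoked by the shutdown hook on abrupt exit
}

final class PushDownRetrySketch {
    static Object[] readFirstPage(List<String> hosts, Function<String, PageReader> connect) {
        int failedTimes = 0;
        for (String host : hosts) {
            PageReader[] holder = new PageReader[1];
            // Register the hook before the blocking call so a killed JVM can still
            // release the server-side task.
            Thread hook = new Thread(() -> {
                if (holder[0] != null) {
                    holder[0].forceClose();
                }
            });
            try {
                Runtime.getRuntime().addShutdownHook(hook);
                holder[0] = connect.apply(host);
                Object[] page = holder[0].getNextPageBlocking();
                if (holder[0].isFinished()) {
                    holder[0].close();
                    Runtime.getRuntime().removeShutdownHook(hook);
                }
                return page;
            } catch (Exception e) {
                failedTimes++;
                if (holder[0] != null) {
                    holder[0].close();
                }
                Runtime.getRuntime().removeShutdownHook(hook);
            }
        }
        throw new IllegalStateException("No Omni-data-server to connect, failed " + failedTimes + " times");
    }
}
```

In the patch itself the hook stays registered while hasNextPage is true, since later pages are pulled through hasNextIterator; the sketch collapses that to a single read for brevity.
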
List l = new ArrayList<>(); l.add(page); pageList.addAll(pageToColumnarClass - .transPageToColumnar(l.iterator(), isVectorizedReader)); + .transPageToColumnar(l.iterator(), isVectorizedReader, isOperatorCombineEnabled, sparkOutPut, orcImpl)); return true; } private void initCandidates(PageCandidate pageCandidate, Seq filterOutPut) { - omnidataTypes.clear(); - omnidataColumns.clear(); - omnidataProjections.clear(); - fieldMap.clear(); - columnNameSet.clear(); - columnTypesList.clear(); - columnOrdersList.clear(); - filterTypesList.clear(); - filterOrdersList.clear(); - partitionColumnName.clear(); - columnNameMap.clear(); - columnOrder = 0; + initCandidatesBeforeDomain(filterOutPut); filePath = pageCandidate.getFilePath(); columnOffset = pageCandidate.getColumnOffset(); - listAtt = JavaConverters.seqAsJavaList(filterOutPut); TASK_FAILED_TIMES = pageCandidate.getMaxFailedTimes(); taskTimeout = pageCandidate.getTaskTimeout(); } - private RowExpression extractNamedExpression(Expression namedExpression) { - Type prestoType = NdpUtils.transOlkDataType(namedExpression.dataType(), false); + private RowExpression extractNamedExpression(NamedExpression namedExpression) { + Type prestoType = NdpUtils.transOlkDataType(((Expression) namedExpression).dataType(), namedExpression, + false); int aggProjectionId; - String aggColumnName = namedExpression.toString().split("#")[0].toLowerCase(Locale.ENGLISH); + String aggColumnName = namedExpression.name(); columnOrdersList.add(columnOrder++); - columnTypesList.add(NdpUtils.transDataIoDataType(namedExpression.dataType())); + columnTypesList.add(NdpUtils.transDecodeType(((Expression) namedExpression).dataType())); if (null != fieldMap.get(aggColumnName)) { aggProjectionId = fieldMap.get(aggColumnName); @@ -327,9 +381,7 @@ public class DataIoAdapter { omnidataColumns.add(new Column(columnId, aggColumnName, prestoType, isPartitionKey, partitionValue)); columnNameSet.add(aggColumnName); - if (null == columnNameMap.get(aggColumnName)) { - columnNameMap.put(aggColumnName, columnNameMap.size()); - } + columnNameMap.computeIfAbsent(aggColumnName, k -> columnNameMap.size()); omnidataProjections.add(new InputReferenceExpression(aggProjectionId, prestoType)); } @@ -398,9 +450,9 @@ public class DataIoAdapter { String operatorName, Type prestoType) { List arguments = new ArrayList<>(); Type leftPrestoType = NdpUtils.transOlkDataType( - expression.left().dataType(), false); + expression.left().dataType(), expression.left(), false); Type rightPrestoType = NdpUtils.transOlkDataType( - expression.right().dataType(), false); + expression.right().dataType(), expression.right(), false); FunctionHandle functionHandle = new BuiltInFunctionHandle( new Signature(QualifiedObjectName.valueOfDefaultFunction("$operator$" + operatorName), SCALAR, prestoType.getTypeSignature(), @@ -413,7 +465,7 @@ public class DataIoAdapter { } private RowExpression createAggProjection(Expression expression) { - Type prestoType = NdpUtils.transOlkDataType(expression.dataType(), false); + Type prestoType = NdpUtils.transOlkDataType(expression.dataType(), expression, false); AggExpressionType aggExpressionType = AggExpressionType .valueOf(expression.getClass().getSimpleName()); switch (aggExpressionType) { @@ -426,11 +478,9 @@ public class DataIoAdapter { case Divide: return createAggBinCall((Divide) expression, "Divide", prestoType); case Remainder: - return createAggBinCall((Remainder) expression, "Remainder", prestoType); + return createAggBinCall((Remainder) expression, "Modulus", prestoType); case 
Literal: - Object value = NdpUtils.transData( - expression.dataType().toString(), expression.toString()); - return new ConstantExpression(value, prestoType); + return NdpUtils.transConstantExpression(expression.toString(), prestoType); case AttributeReference: String aggColumnName = expression.toString().split("#")[0].toLowerCase(Locale.ENGLISH); int field; @@ -477,8 +527,9 @@ public class DataIoAdapter { omnidataProjections.add(createAggProjection(expression)); int projectionId = fieldMap.size(); fieldMap.put(aggregateFunctionName, projectionId); - if (aggregateFunctionType.equals(AggregateFunctionType.Count)) { - prestoType = NdpUtils.transOlkDataType(expression.dataType(), false); + if (aggregateFunctionType.equals(AggregateFunctionType.Count) + || aggregateFunctionType.equals(AggregateFunctionType.Average)) { + prestoType = NdpUtils.transOlkDataType(expression.dataType(), expression, false); } omnidataTypes.add(prestoType); break; @@ -523,7 +574,8 @@ public class DataIoAdapter { LessThanOrEqual, In, HiveSimpleUDF, - IsNull + IsNull, + AttributeReference } private Optional createAggregationInfo( @@ -533,7 +585,7 @@ public class DataIoAdapter { Map aggregationMap = new LinkedHashMap<>(); boolean isEmpty = true; for (NamedExpression namedExpression : namedExpressions) { - RowExpression groupingKey = extractNamedExpression((Expression) namedExpression); + RowExpression groupingKey = extractNamedExpression(namedExpression); groupingKeys.add(groupingKey); isEmpty = false; } @@ -545,28 +597,6 @@ public class DataIoAdapter { new AggregationInfo(aggregationMap, groupingKeys)); } - private Optional extractAggAndGroupExpression( - List aggExecutionList) { - Optional resAggregationInfo = Optional.empty(); - for (AggExeInfo aggExeInfo : aggExecutionList) { - List aggregateExpressions = JavaConverters.seqAsJavaList( - aggExeInfo.aggregateExpressions()); - List namedExpressions = JavaConverters.seqAsJavaList( - aggExeInfo.groupingExpressions()); - resAggregationInfo = createAggregationInfo(aggregateExpressions, namedExpressions); - } - return resAggregationInfo; - } - - private RowExpression extractFilterExpression(Seq filterExecution) { - List filterExecutionList = JavaConverters.seqAsJavaList(filterExecution); - RowExpression resRowExpression = null; - for (FilterExeInfo filterExeInfo : filterExecutionList) { - resRowExpression = reverseExpressionTree(filterExeInfo.filter()); - } - return resRowExpression; - } - private RowExpression reverseExpressionTree(Expression filterExpression) { RowExpression resRowExpression = null; if (filterExpression == null) { @@ -595,170 +625,152 @@ public class DataIoAdapter { private RowExpression getExpression(Expression filterExpression) { RowExpression resRowExpression = null; - List rightExpressions = new ArrayList<>(); - ExpressionOperator expressionOperType = - ExpressionOperator.valueOf(filterExpression.getClass().getSimpleName()); - Expression left; - Expression right; - String operatorName; + ExpressionOperator expressionOperType = ExpressionOperator.valueOf(filterExpression.getClass().getSimpleName()); switch (expressionOperType) { case Or: case And: return reverseExpressionTree(filterExpression); - case Not: - Signature notSignature = new Signature( - QualifiedObjectName.valueOfDefaultFunction("not"), - FunctionKind.SCALAR, new TypeSignature("boolean"), - new TypeSignature("boolean")); - RowExpression tempRowExpression = getExpression(((Not) filterExpression).child()); - List notArguments = new ArrayList<>(); - notArguments.add(tempRowExpression); - return 
new CallExpression("not", new BuiltInFunctionHandle(notSignature), - BOOLEAN, notArguments, Optional.empty()); case EqualTo: - if (((EqualTo) filterExpression).left() instanceof Literal) { - rightExpressions.add(((EqualTo) filterExpression).left()); - left = ((EqualTo) filterExpression).right(); - } else { - rightExpressions.add(((EqualTo) filterExpression).right()); - left = ((EqualTo) filterExpression).left(); - } - return getRowExpression(left, - "equal", rightExpressions); - case IsNotNull: - Signature isnullSignature = new Signature( - QualifiedObjectName.valueOfDefaultFunction("not"), - FunctionKind.SCALAR, new TypeSignature("boolean"), - new TypeSignature("boolean")); - RowExpression isnullRowExpression = - getRowExpression(((IsNotNull) filterExpression).child(), - "is_null", null); - List isnullArguments = new ArrayList<>(); - isnullArguments.add(isnullRowExpression); - return new CallExpression("not", new BuiltInFunctionHandle(isnullSignature), - BOOLEAN, isnullArguments, Optional.empty()); - case IsNull: - return getRowExpression(((IsNull) filterExpression).child(), - "is_null", null); + return getRowExpression(flatBinaryExpression(filterExpression), "EQUAL"); case LessThan: - if (((LessThan) filterExpression).left() instanceof Literal) { - rightExpressions.add(((LessThan) filterExpression).left()); - left = ((LessThan) filterExpression).right(); - operatorName = "greater_than"; - } else { - rightExpressions.add(((LessThan) filterExpression).right()); - left = ((LessThan) filterExpression).left(); - operatorName = "less_than"; - } - return getRowExpression(left, - operatorName, rightExpressions); + return getRowExpression(flatBinaryExpression(filterExpression), "LESS_THAN"); + case LessThanOrEqual: + return getRowExpression(flatBinaryExpression(filterExpression), "LESS_THAN_OR_EQUAL"); case GreaterThan: - if (((GreaterThan) filterExpression).left() instanceof Literal) { - rightExpressions.add(((GreaterThan) filterExpression).left()); - left = ((GreaterThan) filterExpression).right(); - operatorName = "less_than"; - } else { - rightExpressions.add(((GreaterThan) filterExpression).right()); - left = ((GreaterThan) filterExpression).left(); - operatorName = "greater_than"; - } - return getRowExpression(left, - operatorName, rightExpressions); + return getRowExpression(flatBinaryExpression(filterExpression), "GREATER_THAN"); case GreaterThanOrEqual: - if (((GreaterThanOrEqual) filterExpression).left() instanceof Literal) { - rightExpressions.add(((GreaterThanOrEqual) filterExpression).left()); - left = ((GreaterThanOrEqual) filterExpression).right(); - operatorName = "less_than_or_equal"; + return getRowExpression(flatBinaryExpression(filterExpression), "GREATER_THAN_OR_EQUAL"); + case IsNull: + return getRowExpression(flatUnaryExpression(filterExpression), "IS_NULL"); + case AttributeReference: + Type type = NdpUtils.transOlkDataType(filterExpression.dataType(), filterExpression, false); + return new InputReferenceExpression(putFilterValue(filterExpression, type), type); + case HiveSimpleUDF: + if (filterExpression instanceof HiveSimpleUDF) { + return getRowExpression(Arrays.asList(filterExpression), ((HiveSimpleUDF) filterExpression).name()); } else { - rightExpressions.add(((GreaterThanOrEqual) filterExpression).right()); - left = ((GreaterThanOrEqual) filterExpression).left(); - operatorName = "greater_than_or_equal"; + return resRowExpression; } - return getRowExpression(left, - operatorName, rightExpressions); - case LessThanOrEqual: - if (((LessThanOrEqual) 
filterExpression).left() instanceof Literal) { - rightExpressions.add(((LessThanOrEqual) filterExpression).left()); - left = ((LessThanOrEqual) filterExpression).right(); - operatorName = "greater_than_or_equal"; + case Not: + case IsNotNull: + // get child RowExpression + RowExpression childRowExpression; + if (expressionOperType == ExpressionOperator.IsNotNull) { + childRowExpression = getRowExpression(flatUnaryExpression(filterExpression), "IS_NULL"); } else { - rightExpressions.add(((LessThanOrEqual) filterExpression).right()); - left = ((LessThanOrEqual) filterExpression).left(); - operatorName = "less_than_or_equal"; + if (filterExpression instanceof Not) { + Expression childExpression = ((Not) filterExpression).child(); + // use "NOT_EQUAL" to adapt to omniOperator + if (childExpression instanceof EqualTo) { + return getRowExpression(flatBinaryExpression(childExpression), "NOT_EQUAL"); + } + childRowExpression = getExpression(childExpression); + } else { + return resRowExpression; + } } - return getRowExpression(left, - operatorName, rightExpressions); + // get not(child) RowExpression + List notArguments = new ArrayList<>(); + notArguments.add(childRowExpression); + return NdpFilterUtils.generateRowExpression("NOT", notArguments, new TypeSignature("boolean")); case In: - List rightExpression = - JavaConverters.seqAsJavaList(((In) filterExpression).list()); - return getRowExpression(((In) filterExpression).value(), "in", rightExpression); - case HiveSimpleUDF: - return getRowExpression(filterExpression, - ((HiveSimpleUDF) filterExpression).name(), rightExpressions); + if (filterExpression instanceof In) { + // flat "in" expression + In in = (In) filterExpression; + List rightExpressions = JavaConverters.seqAsJavaList(in.list()); + List expressions = new ArrayList<>(rightExpressions.size() + 1); + expressions.add(in.value()); + expressions.addAll(rightExpressions); + // get RowExpression + return getRowExpression(expressions, "IN"); + } else { + return resRowExpression; + } default: return resRowExpression; } } - private RowExpression getRowExpression(Expression leftExpression, String operatorName, - List rightExpression) { + private List flatBinaryExpression(Expression expression) { + if (expression instanceof BinaryComparison) { + BinaryComparison binaryComparison = (BinaryComparison) expression; + return Arrays.asList(binaryComparison.left(), binaryComparison.right()); + } else { + return new ArrayList<>(); + } + } + + private List flatUnaryExpression(Expression expression) { + if (expression instanceof UnaryExpression) { + UnaryExpression unaryExpression = (UnaryExpression) expression; + return Arrays.asList(unaryExpression.child()); + } else { + return new ArrayList<>(); + } + } + + private RowExpression getRowExpression(List expressions, String operatorName) { + RowExpression rowExpression = null; + // get filter type + Type filterType = null; + for (Expression expression : expressions) { + if (!(expression instanceof Literal)) { + if (expression instanceof AttributeReference) { + filterType = NdpUtils.transOlkDataType(expression.dataType(), expression, false); + } else { + filterType = NdpUtils.transOlkDataType(expression.dataType(), true); + } + break; + } + } + if (filterType == null) { + return rowExpression; + } + // create arguments + List arguments = new ArrayList<>(); + for (Expression expression : expressions) { + if (expression instanceof Literal) { + arguments.add(NdpUtils.transConstantExpression(expression.toString(), filterType)); + } else { + 
arguments.add(NdpFilterUtils.createRowExpressionForColumn(getColumnInfo(expression))); + } + } + return NdpFilterUtils.generateRowExpression(operatorName, arguments, filterType.getTypeSignature()); + } + + private ColumnInfo getColumnInfo(Expression expression) { PrestoExpressionInfo expressionInfo = new PrestoExpressionInfo(); Type prestoType; int filterProjectionId; - // deal with left expression only UDF and Attribute - if (leftExpression instanceof AttributeReference) { - prestoType = NdpUtils.transOlkDataType(leftExpression.dataType(), false); - filterProjectionId = putFilterValue(leftExpression, prestoType); - } else if (leftExpression instanceof Cast && operatorName.equals("in")) { - prestoType = NdpUtils.transOlkDataType(((Cast) leftExpression).child().dataType(), false); - filterProjectionId = putFilterValue(((Cast) leftExpression).child(), prestoType); - } else { - if (leftExpression instanceof HiveSimpleUDF) { - for (int i = 0; i < leftExpression.children().length(); i++) { - Expression childExpr = leftExpression.children().apply(i); - if (!(childExpr instanceof Literal)) { - putFilterValue(childExpr, NdpUtils.transOlkDataType(childExpr.dataType(), false)); - } + // deal with expression only UDF and Attribute + if (expression instanceof AttributeReference) { + prestoType = NdpUtils.transOlkDataType(expression.dataType(), expression, false); + filterProjectionId = putFilterValue(expression, prestoType); + } else if (expression instanceof HiveSimpleUDF) { + for (int i = 0; i < expression.children().length(); i++) { + Expression childExpr = expression.children().apply(i); + if (childExpr instanceof Attribute) { + putFilterValue(childExpr, NdpUtils.transOlkDataType(childExpr.dataType(), + childExpr, false)); + } else if (!(childExpr instanceof Literal)) { + putFilterValue(childExpr, NdpUtils.transOlkDataType(childExpr.dataType(), false)); } - ndpUdfExpressions.createNdpUdf(leftExpression, expressionInfo, fieldMap); - } else { - ndpUdfExpressions.createNdpUdf(leftExpression, expressionInfo, fieldMap); - putFilterValue(expressionInfo.getChildExpression(), expressionInfo.getFieldDataType()); } + ndpUdfExpressions.createNdpUdf(expression, expressionInfo, fieldMap); prestoType = expressionInfo.getReturnType(); filterProjectionId = expressionInfo.getProjectionId(); - } - // deal with right expression - List argumentValues = new ArrayList<>(); - List multiArguments = new ArrayList<>(); - int rightProjectionId = -1; - RowExpression rowExpression; - if (rightExpression != null && rightExpression.size() > 0 && - rightExpression.get(0) instanceof AttributeReference) { - rightProjectionId = putFilterValue(rightExpression.get(0), prestoType); - multiArguments.add(new InputReferenceExpression(filterProjectionId, prestoType)); - multiArguments.add(new InputReferenceExpression(rightProjectionId, prestoType)); - rowExpression = NdpFilterUtils.generateRowExpression( - operatorName, expressionInfo, prestoType, filterProjectionId, - null, multiArguments, "multy_columns"); } else { - // get right value - if (NdpUtils.isInDateExpression(leftExpression, operatorName)) { - argumentValues = getDateValue(rightExpression); - } else { - argumentValues = getValue(rightExpression, operatorName, - leftExpression.dataType().toString()); - } - rowExpression = NdpFilterUtils.generateRowExpression( - operatorName, expressionInfo, prestoType, filterProjectionId, - argumentValues, null, operatorName); + ndpUdfExpressions.createNdpUdf(expression, expressionInfo, fieldMap); + 
putFilterValue(expressionInfo.getChildExpression(), expressionInfo.getFieldDataType()); + prestoType = expressionInfo.getReturnType(); + filterProjectionId = expressionInfo.getProjectionId(); } - return rowExpression; + return new ColumnInfo(expressionInfo, prestoType, filterProjectionId); } - // column projection赋值 + // column projection private int putFilterValue(Expression valueExpression, Type prestoType) { - // Filter赋值 + // set filter int columnId = NdpUtils.getColumnId(valueExpression.toString()) - columnOffset; String filterColumnName = valueExpression.toString().split("#")[0].toLowerCase(Locale.ENGLISH); if (null != fieldMap.get(filterColumnName)) { @@ -767,72 +779,32 @@ public class DataIoAdapter { boolean isPartitionKey = partitionColumnName.contains(filterColumnName); int filterProjectionId = fieldMap.size(); fieldMap.put(filterColumnName, filterProjectionId); - filterTypesList.add(NdpUtils.transDataIoDataType(valueExpression.dataType())); - filterOrdersList.add(filterProjectionId); + String partitionValue = NdpUtils.getPartitionValue(filePath, filterColumnName); columnNameSet.add(filterColumnName); - omnidataProjections.add(new InputReferenceExpression(filterProjectionId, prestoType)); omnidataColumns.add(new Column(columnId, filterColumnName, prestoType, isPartitionKey, partitionValue)); - omnidataTypes.add(prestoType); + if (isPushDownAgg) { + filterTypesList.add(NdpUtils.transDecodeType(valueExpression.dataType())); + filterOrdersList.add(filterProjectionId); + omnidataProjections.add(new InputReferenceExpression(filterProjectionId, prestoType)); + omnidataTypes.add(prestoType); + } if (null == columnNameMap.get(filterColumnName)) { columnNameMap.put(filterColumnName, columnNameMap.size()); } return filterProjectionId; } - // for date parse - private List getDateValue(List rightExpression) { - long DAY_TO_MILL_SECS = 24L * 3600L * 1000L; - List dateTimes = new ArrayList<>(); - for (Expression rExpression : rightExpression) { - String dateStr = rExpression.toString(); - if (NdpUtils.isValidDateFormat(dateStr)) { - String[] dateStrArray = dateStr.split("-"); - int year = Integer.parseInt(dateStrArray[0]) - 1900; - int month = Integer.parseInt(dateStrArray[1]) - 1; - int day = Integer.parseInt(dateStrArray[2]); - Date date = new Date(year, month, day); - dateTimes.add(String.valueOf((date.getTime() - date.getTimezoneOffset() * 60000L) / DAY_TO_MILL_SECS)); - } else { - throw new UnsupportedOperationException("decode date failed: " + dateStr); - } - } - return dateTimes; - } - - private List getValue(List rightExpression, - String operatorName, - String sparkType) { - Object objectValue; - List argumentValues = new ArrayList<>(); - if (null == rightExpression || rightExpression.size() == 0) { - return argumentValues; - } - switch (operatorName.toLowerCase(Locale.ENGLISH)) { - case "in": - List inValue = new ArrayList<>(); - for (Expression rExpression : rightExpression) { - inValue.add(rExpression.toString()); - } - argumentValues = inValue; - break; - default: - argumentValues.add(rightExpression.get(0).toString()); - break; - } - return argumentValues; - } - - private PageDeserializer initPageDeserializer() { + private PageDeserializer initPageDeserializer(String fileFormat) { DecodeType[] columnTypes = columnTypesList.toArray(new DecodeType[0]); int[] columnOrders = columnOrdersList.stream().mapToInt(Integer::intValue).toArray(); DecodeType[] filterTypes = filterTypesList.toArray(new DecodeType[0]); int[] filterOrders = 
filterOrdersList.stream().mapToInt(Integer::intValue).toArray(); if (columnTypes.length == 0) { - return new PageDeserializer(filterTypes, filterOrders); + return new PageDeserializer(filterTypes, filterOrders, isOperatorCombineEnabled, fileFormat); } else { - return new PageDeserializer(columnTypes, columnOrders); + return new PageDeserializer(columnTypes, columnOrders, isOperatorCombineEnabled, fileFormat); } } @@ -852,14 +824,28 @@ public class DataIoAdapter { return dataSource; } - private RowExpression initFilter(Seq filterExecutions) { - return extractFilterExpression(filterExecutions); + public Optional initFilter(Seq filterExecutions) { + List filterExecutionList = JavaConverters.seqAsJavaList(filterExecutions); + Optional resRowExpression = Optional.empty(); + for (FilterExeInfo filterExeInfo : filterExecutionList) { + resRowExpression = Optional.ofNullable(reverseExpressionTree(filterExeInfo.filter())); + } + return resRowExpression; } private Optional initAggAndGroupInfo( - List aggExecutionList) { - // create AggregationInfo - return extractAggAndGroupExpression(aggExecutionList); + Seq aggExeInfoSeq) { + List aggExecutionList = + JavaConverters.seqAsJavaList(aggExeInfoSeq); + Optional resAggregationInfo = Optional.empty(); + for (AggExeInfo aggExeInfo : aggExecutionList) { + List aggregateExpressions = JavaConverters.seqAsJavaList( + aggExeInfo.aggregateExpressions()); + List namedExpressions = JavaConverters.seqAsJavaList( + aggExeInfo.groupingExpressions()); + resAggregationInfo = createAggregationInfo(aggregateExpressions, namedExpressions); + } + return resAggregationInfo; } private void initColumnInfo(Seq sparkOutPut) { @@ -873,20 +859,98 @@ public class DataIoAdapter { for (Attribute attribute : outputColumnList) { Attribute resAttribute = NdpUtils.getColumnAttribute(attribute, listAtt); String columnName = resAttribute.name().toLowerCase(Locale.ENGLISH); - Type type = NdpUtils.transOlkDataType(resAttribute.dataType(), false); + Type type = NdpUtils.transOlkDataType(resAttribute.dataType(), resAttribute, false); int columnId = NdpUtils.getColumnId(resAttribute.toString()) - columnOffset; isPartitionKey = partitionColumnName.contains(columnName); String partitionValue = NdpUtils.getPartitionValue(filePath, columnName); omnidataColumns.add(new Column(columnId, columnName, type, isPartitionKey, partitionValue)); omnidataTypes.add(type); - filterTypesList.add(NdpUtils.transDataIoDataType(resAttribute.dataType())); + filterTypesList.add(NdpUtils.transDecodeType(resAttribute.dataType())); filterOrdersList.add(filterColumnId); omnidataProjections.add(new InputReferenceExpression(filterColumnId, type)); fieldMap.put(columnName, filterColumnId); ++filterColumnId; } } -} + public boolean isOperatorCombineEnabled() { + return isOperatorCombineEnabled; + } + private void initCandidatesBeforeDomain(Seq filterOutPut) { + omnidataTypes.clear(); + omnidataColumns.clear(); + omnidataProjections.clear(); + columnNameSet.clear(); + columnTypesList.clear(); + columnOrdersList.clear(); + fieldMap.clear(); + filterTypesList.clear(); + filterOrdersList.clear(); + columnNameMap.clear(); + columnOrder = 0; + partitionColumnName.clear(); + listAtt = JavaConverters.seqAsJavaList(filterOutPut); + isPushDownAgg = true; + } + + public ImmutableMap buildDomains( + Seq sparkOutPut, + Seq partitionColumn, + Seq filterOutPut, + PushDownInfo pushDownOperators, + TaskContext context) { + // initCandidates + initCandidatesBeforeDomain(filterOutPut); + + // add partition column + 
JavaConverters.seqAsJavaList(partitionColumn).forEach(a -> partitionColumnName.add(a.name())); + + // init column info + if (pushDownOperators.aggExecutions().size() == 0) { + isPushDownAgg = false; + initColumnInfo(sparkOutPut); + } + + // create filter + Optional filterRowExpression = initFilter(pushDownOperators.filterExecutions()); + + long startTime = System.currentTimeMillis(); + ImmutableMap.Builder domains = ImmutableMap.builder(); + if (filterRowExpression.isPresent() && NdpConf.getNdpDomainGenerateEnable(context)) { + ConnectorSession session = MetaStore.getConnectorSession(); + RowExpressionDomainTranslator domainTranslator = new RowExpressionDomainTranslator(MetaStore.getMetadata()); + DomainTranslator.ColumnExtractor columnExtractor = (expression, domain) -> { + if (expression instanceof InputReferenceExpression) { + return Optional.of((InputReferenceExpression) expression); + } + return Optional.empty(); + }; + DomainTranslator.ExtractionResult extractionResult = domainTranslator + .fromPredicate(session, filterRowExpression.get(), columnExtractor); + if (!extractionResult.getTupleDomain().isNone()) { + extractionResult.getTupleDomain().getDomains().get().forEach((columnHandle, domain) -> { + Type type = domain.getType(); + // unSupport dataType skip + if (type instanceof MapType || + type instanceof ArrayType || + type instanceof RowType || + type instanceof DecimalType || + type instanceof TimestampType) { + return; + } + + domains.put(omnidataColumns.get(columnHandle.getField()).getName(), domain); + }); + } + } + + ImmutableMap domainImmutableMap = domains.build(); + long costTime = System.currentTimeMillis() - startTime; + if (LOG.isDebugEnabled()) { + LOG.debug("Push down generate domain cost time:" + costTime + ";generate domain:" + domainImmutableMap.size()); + } + return domainImmutableMap; + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/MetaStore.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/MetaStore.java new file mode 100644 index 0000000000000000000000000000000000000000..0016eaed335a392d3b5671f2f75a54159cb34f34 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/MetaStore.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql; + +import com.esotericsoftware.kryo.Kryo; +import com.google.common.collect.ImmutableSet; +import io.prestosql.metadata.*; +import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.security.ConnectorIdentity; +import io.prestosql.spi.type.TimeZoneKey; +import io.prestosql.sql.analyzer.FeaturesConfig; +import io.prestosql.transaction.NoOpTransactionManager; +import io.prestosql.transaction.TransactionManager; + +import java.util.Locale; +import java.util.Optional; +import java.util.TimeZone; + +/** + * Used to initialize some common classes + * + * @since 2023.04 + */ +public class MetaStore { + private static final Metadata metadata = initCompiler(); + private static final ConnectorSession connectorSession = initConnectorSession(); + + private MetaStore() { + } + + private static Metadata initCompiler() { + FeaturesConfig featuresConfig = new FeaturesConfig(); + TransactionManager transactionManager = new NoOpTransactionManager(); + return new MetadataManager( + new FunctionAndTypeManager( + transactionManager, + featuresConfig, + new HandleResolver(), + ImmutableSet.of(), + new Kryo()), + featuresConfig, + new SessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + transactionManager, + null); + } + + /** + * get Metadata instance + * + * @return Metadata + */ + public static Metadata getMetadata() { + return metadata; + } + + private static ConnectorSession initConnectorSession() { + return new ConnectorSession() { + @Override + public String getQueryId() { + return "null"; + } + + @Override + public Optional getSource() { + return Optional.empty(); + } + + @Override + public ConnectorIdentity getIdentity() { + return null; + } + + @Override + public TimeZoneKey getTimeZoneKey() { + return TimeZoneKey.getTimeZoneKey(TimeZone.getDefault().getID()); + } + + @Override + public Locale getLocale() { + return Locale.getDefault(); + } + + @Override + public Optional getTraceToken() { + return Optional.empty(); + } + + @Override + public long getStartTime() { + return 0; + } + + @Override + public T getProperty(String name, Class cls) { + return null; + } + }; + } + + /** + * get ConnectorSession instance + * + * @return ConnectorSession + */ + public static ConnectorSession getConnectorSession() { + return connectorSession; + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpFilterUtils.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpFilterUtils.java index 2897ec183ebae74f69e6cbaeb16631db4c12b284..1252073e2034ea2adc0906e838f58d1ca366bdf4 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpFilterUtils.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpFilterUtils.java @@ -32,14 +32,13 @@ import io.prestosql.spi.relation.RowExpression; import io.prestosql.spi.relation.SpecialForm; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; -import io.prestosql.spi.type.TypeSignatureParameter; import org.apache.spark.sql.catalyst.expressions.Expression; -import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; /** * NdpFilterUtils @@ -56,99 +55,53 @@ public class NdpFilterUtils { } public static RowExpression generateRowExpression( - String signatureName, PrestoExpressionInfo 
expressionInfo, - Type prestoType, int filterProjectionId, - List argumentValues, - List multiArguments, String operatorName) { + String operatorName, + List arguments, + TypeSignature typeSignature) { RowExpression rowExpression; - List rowArguments; - String prestoName = prestoType.toString(); - TypeSignature paramRight; - TypeSignature paramLeft; - if (prestoType.toString().contains("decimal")) { - String[] parameter = prestoName.split("\\(")[1].split("\\)")[0].split(","); - long precision = Long.parseLong(parameter[0]); - long scale = Long.parseLong(parameter[1]); - paramRight = new TypeSignature("decimal", TypeSignatureParameter.of(precision), TypeSignatureParameter.of(scale)); - paramLeft = new TypeSignature("decimal", TypeSignatureParameter.of(precision), TypeSignatureParameter.of(scale)); - } else { - paramRight = new TypeSignature(prestoName); - paramLeft = new TypeSignature(prestoName); - } - Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction("$operator$" + - signatureName.toLowerCase(Locale.ENGLISH)), - FunctionKind.SCALAR, new TypeSignature("boolean"), - paramRight, paramLeft); switch (operatorName.toLowerCase(Locale.ENGLISH)) { case "is_null": - List notnullArguments = new ArrayList<>(); - if (expressionInfo.isUDF()) { - notnullArguments.add(expressionInfo.getPrestoRowExpression()); - } else { - notnullArguments.add(new InputReferenceExpression(filterProjectionId, prestoType)); - } - rowExpression = new SpecialForm(IS_NULL, BOOLEAN, notnullArguments); + rowExpression = new SpecialForm(IS_NULL, BOOLEAN, arguments); break; case "in": - rowArguments = getConstantArguments(prestoType, argumentValues, filterProjectionId); - rowExpression = new SpecialForm(IN, BOOLEAN, rowArguments); + rowExpression = new SpecialForm(IN, BOOLEAN, arguments); break; - case "multy_columns": - Signature signatureMulti = new Signature( - QualifiedObjectName.valueOfDefaultFunction("$operator$" - + signatureName.toLowerCase(Locale.ENGLISH)), - FunctionKind.SCALAR, new TypeSignature("boolean"), - new TypeSignature(prestoType.toString()), - new TypeSignature(prestoType.toString())); - rowExpression = new CallExpression(signatureName, - new BuiltInFunctionHandle(signatureMulti), BOOLEAN, multiArguments); + case "not": + Signature notSignature = new Signature(QualifiedObjectName.valueOfDefaultFunction("not"), + FunctionKind.SCALAR, new TypeSignature("boolean"), typeSignature); + rowExpression = new CallExpression(operatorName.toLowerCase(Locale.ENGLISH), + new BuiltInFunctionHandle(notSignature), BOOLEAN, arguments); break; case "isempty": case "isdeviceidlegal": case "ismessycode": case "dateudf": - rowExpression = expressionInfo.getPrestoRowExpression(); + rowExpression = arguments.get(0); break; default: - if (expressionInfo.getReturnType() != null) { - rowArguments = getUdfArguments(prestoType, - argumentValues, expressionInfo.getPrestoRowExpression()); - } else { - rowArguments = getConstantArguments(prestoType, - argumentValues, filterProjectionId); - } - rowExpression = new CallExpression(signatureName, - new BuiltInFunctionHandle(signature), BOOLEAN, rowArguments); + Signature signature = new Signature( + QualifiedObjectName.valueOfDefaultFunction("$operator$" + + operatorName.toLowerCase(Locale.ENGLISH)), + FunctionKind.SCALAR, new TypeSignature("boolean"), typeSignature, typeSignature); + // To adapt to the omniOperator, use uppercase operatorName + rowExpression = new CallExpression(operatorName, + new BuiltInFunctionHandle(signature), BOOLEAN, arguments); break; } 
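        // Illustrative usage (not part of this patch): with the refactored signature the
        // caller supplies ready-made RowExpression arguments plus a single operand type
        // signature. For example, pushing down a hypothetical "age > 30" predicate on a
        // BIGINT column at ordinal 0 could look like this (ordinal and literal are
        // placeholders for the example):
        //     List<RowExpression> args = new ArrayList<>();
        //     args.add(new InputReferenceExpression(0, BigintType.BIGINT));
        //     args.add(new ConstantExpression(30L, BigintType.BIGINT));
        //     RowExpression gt = NdpFilterUtils.generateRowExpression(
        //             "GREATER_THAN", args, BigintType.BIGINT.getTypeSignature());
        // The uppercase operator name follows the same convention DataIoAdapter uses
        // elsewhere in this patch (e.g. "NOT_EQUAL", "IN", "IS_NULL").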
return rowExpression; } - public static List getConstantArguments(Type typeStr, - List argumentValues, - int columnId) { - List arguments = new ArrayList<>(); - arguments.add(new InputReferenceExpression(columnId, typeStr)); - if (null != argumentValues && argumentValues.size() > 0) { - for (Object argumentValue : argumentValues) { - arguments.add(NdpUtils - .transArgumentData(argumentValue.toString(), typeStr)); - } - } - return arguments; - } - - public static List getUdfArguments(Type typeStr, List argumentValues, - RowExpression callExpression) { - List arguments = new ArrayList<>(); - arguments.add(callExpression); - if (null != argumentValues && argumentValues.size() > 0) { - for (Object argumentValue : argumentValues) { - arguments.add(NdpUtils - .transArgumentData(argumentValue.toString(), typeStr)); - } + /** + * create RowExpression for column + * + * @param columnInfo column info + * @return RowExpression produced by column info + */ + public static RowExpression createRowExpressionForColumn(ColumnInfo columnInfo) { + if (columnInfo.getExpressionInfo().getReturnType() != null) { + return columnInfo.getExpressionInfo().getPrestoRowExpression(); + } else { + return new InputReferenceExpression(columnInfo.getFilterProjectionId(), columnInfo.getPrestoType()); } - return arguments; } } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfEnum.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfEnum.java index 02185f29372e28647b8477a366739f9d30ac4e8f..c9d39027f0ef367d2a6524e20bc9d53a48305756 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfEnum.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfEnum.java @@ -27,7 +27,7 @@ public enum NdpUdfEnum { LENGTH("length","length"), UPPER("upper","upper"), LOWER("lower","lower"), - CAST("cast","$operator$cast"), + CAST("CAST","$operator$cast"), REPLACE("replace","replace"), INSTR("instr","instr"), SUBSCRIPT("SUBSCRIPT","$operator$subscript"), diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfExpressions.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfExpressions.java index 745b1fb219b31732ef8c18e8d266f2d5ad838889..0567d029b580b35e642c79048d0a6fc1fd892960 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfExpressions.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUdfExpressions.java @@ -26,20 +26,9 @@ import io.prestosql.spi.relation.CallExpression; import io.prestosql.spi.relation.InputReferenceExpression; import io.prestosql.spi.relation.RowExpression; import io.prestosql.spi.type.*; +import org.apache.spark.sql.catalyst.expressions.*; import scala.collection.JavaConverters; -import org.apache.spark.sql.catalyst.expressions.AttributeReference; -import org.apache.spark.sql.catalyst.expressions.Cast; -import org.apache.spark.sql.catalyst.expressions.Expression; -import org.apache.spark.sql.catalyst.expressions.GetArrayItem; -import org.apache.spark.sql.catalyst.expressions.Length; -import org.apache.spark.sql.catalyst.expressions.Literal; -import org.apache.spark.sql.catalyst.expressions.Lower; -import org.apache.spark.sql.catalyst.expressions.StringInstr; -import org.apache.spark.sql.catalyst.expressions.StringReplace; -import org.apache.spark.sql.catalyst.expressions.StringSplit; 
-import org.apache.spark.sql.catalyst.expressions.Substring; -import org.apache.spark.sql.catalyst.expressions.Upper; import org.apache.spark.sql.hive.HiveSimpleUDF; import java.util.ArrayList; @@ -47,6 +36,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import static io.prestosql.spi.type.VarcharType.createVarcharType; + /** * Used to process Spark`s UDF, which is converted to presto. * @@ -55,17 +46,19 @@ import java.util.Map; public class NdpUdfExpressions { private void checkAttributeReference(Expression childExpression, - PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap, Type childType, List rowArguments) { + PrestoExpressionInfo prestoExpressionInfo, + Map fieldMap, Type childType, + List rowArguments) { if ((childExpression instanceof AttributeReference)) { int lengthProjectId = NdpFilterUtils.getFilterProjectionId(childExpression, fieldMap); rowArguments.add(new InputReferenceExpression(lengthProjectId, childType)); prestoExpressionInfo.setProjectionId(lengthProjectId); prestoExpressionInfo.setFieldDataType( - NdpUtils.transOlkDataType(childExpression.dataType(), false)); + NdpUtils.transOlkDataType(childExpression.dataType(), childExpression, false)); prestoExpressionInfo.setChildExpression(childExpression); } else if (childExpression instanceof Literal) { - rowArguments.add(NdpUtils.transArgumentData(((Literal) childExpression).value().toString(), childType)); + rowArguments.add(NdpUtils.transConstantExpression(((Literal) childExpression).value().toString(), + childType)); } else { createNdpUdf(childExpression, prestoExpressionInfo, fieldMap); rowArguments.add(prestoExpressionInfo.getPrestoRowExpression()); @@ -76,7 +69,7 @@ public class NdpUdfExpressions { * create Udf */ public void createNdpUdf(Expression udfExpression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { if (udfExpression instanceof Length) { createNdpLength((Length) udfExpression, prestoExpressionInfo, fieldMap); } else if (udfExpression instanceof Upper) { @@ -106,215 +99,237 @@ public class NdpUdfExpressions { * Used to create UDF with only a single parameter */ private void createNdpSingleParameter(NdpUdfEnum udfEnum, - Expression expression, Expression childExpression, - PrestoExpressionInfo prestoExpressionInfo, Map fieldMap) { + Expression expression, Expression childExpression, + PrestoExpressionInfo prestoExpressionInfo, + Map fieldMap) { String signatureName = udfEnum.getSignatureName(); - Type childType = NdpUtils.transOlkDataType(childExpression.dataType(), true); + Type childType = NdpUtils.transOlkDataType(childExpression.dataType(), childExpression, true); + if (childType instanceof CharType) { + childType = createVarcharType(((CharType) childType).getLength()); + } Type returnType = NdpUtils.transOlkDataType(expression.dataType(), true); List rowArguments = new ArrayList<>(); checkAttributeReference(childExpression, - prestoExpressionInfo, fieldMap, childType, rowArguments); - //add decimal TypeSignature judgment - TypeSignature inputParamTypeSignature = NdpUtils.createTypeSignature(childType); - TypeSignature returnParamTypeSignature = NdpUtils.createTypeSignature(returnType); + prestoExpressionInfo, fieldMap, childType, rowArguments); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction(udfEnum.getOperatorName()), - FunctionKind.SCALAR, returnParamTypeSignature,inputParamTypeSignature); + QualifiedObjectName.valueOfDefaultFunction(udfEnum.getOperatorName()), + FunctionKind.SCALAR, 
returnType.getTypeSignature(), childType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setPrestoRowExpression(resExpression); } private void createNdpLength(Length expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { createNdpSingleParameter(NdpUdfEnum.LENGTH, - expression, expression.child(), prestoExpressionInfo, fieldMap); + expression, expression.child(), prestoExpressionInfo, fieldMap); } private void createNdpUpper(Upper expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { createNdpSingleParameter(NdpUdfEnum.UPPER, - expression, expression.child(), prestoExpressionInfo, fieldMap); + expression, expression.child(), prestoExpressionInfo, fieldMap); } private void createNdpLower(Lower expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { createNdpSingleParameter(NdpUdfEnum.LOWER, - expression, expression.child(), prestoExpressionInfo, fieldMap); + expression, expression.child(), prestoExpressionInfo, fieldMap); } private void createNdpCast(Cast expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { createNdpSingleParameter(NdpUdfEnum.CAST, - expression, expression.child(), prestoExpressionInfo, fieldMap); + expression, expression.child(), prestoExpressionInfo, fieldMap); } private void createHiveSimpleUdf(Expression hiveSimpleUDFExpression, - PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + PrestoExpressionInfo prestoExpressionInfo, + Map fieldMap) { String signatureName = ((HiveSimpleUDF) hiveSimpleUDFExpression).name(); List hiveSimpleUdf = JavaConverters.seqAsJavaList( - hiveSimpleUDFExpression.children()); + hiveSimpleUDFExpression.children()); Type returnType = NdpUtils.transOlkDataType( - hiveSimpleUDFExpression.dataType(), false); + hiveSimpleUDFExpression.dataType(), false); List rowArguments = new ArrayList<>(); Type strTypeCandidate = returnType; Signature signature; for (Expression hiveUdf : hiveSimpleUdf) { strTypeCandidate = NdpUtils.transOlkDataType(hiveUdf.dataType(), false); checkAttributeReference(hiveUdf, prestoExpressionInfo, - fieldMap, strTypeCandidate, rowArguments); + fieldMap, strTypeCandidate, rowArguments); } if (hiveSimpleUdf.size() > 0) { - TypeSignature returnTypeSignature = NdpUtils.createTypeSignature(returnType); TypeSignature[] inputTypeSignatures = new TypeSignature[hiveSimpleUdf.size()]; for (int i = 0; i < hiveSimpleUdf.size(); i++) { - inputTypeSignatures[i] = NdpUtils.createTypeSignature(hiveSimpleUdf.get(i).dataType(), false); + Type type = NdpUtils.transOlkDataType(hiveSimpleUdf.get(i).dataType(), false); + inputTypeSignatures[i] = type.getTypeSignature(); } signature = new Signature( - //TODO QualifiedObjectName.valueOf("hive", "default", signatureName), - FunctionKind.SCALAR, returnTypeSignature, + FunctionKind.SCALAR, returnType.getTypeSignature(), inputTypeSignatures); } else { throw new UnsupportedOperationException("The number of UDF parameters is invalid."); } - //TODO signatureName = "hive.default." 
+ signatureName.toLowerCase(Locale.ENGLISH); RowExpression resExpression = new CallExpression(signatureName.toLowerCase(Locale.ENGLISH), - new BuiltInFunctionHandle(signature), returnType, rowArguments); + new BuiltInFunctionHandle(signature), returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setUDF(true); prestoExpressionInfo.setPrestoRowExpression(resExpression); } private void createNdpSubstring(Substring expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { String signatureName = NdpUdfEnum.SUBSTRING.getSignatureName(); - Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), true); + Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), expression.str(), + true); + if (strType instanceof CharType) { + strType = createVarcharType(((CharType) strType).getLength()); + } Type lenType = NdpUtils.transOlkDataType(expression.len().dataType(), true); Type posType = NdpUtils.transOlkDataType(expression.pos().dataType(), true); Type returnType = NdpUtils.transOlkDataType(expression.dataType(), true); + List rowArguments = new ArrayList<>(); checkAttributeReference(expression.str(), - prestoExpressionInfo, fieldMap, strType, rowArguments); - rowArguments.add(NdpUtils.transArgumentData( - expression.pos().toString(), posType)); - rowArguments.add(NdpUtils.transArgumentData( - expression.len().toString(), lenType)); + prestoExpressionInfo, fieldMap, strType, rowArguments); + String startIndex = "0".equals(expression.pos().toString()) ? "1" : expression.pos().toString(); + rowArguments.add(NdpUtils.transConstantExpression( + startIndex, posType)); + rowArguments.add(NdpUtils.transConstantExpression( + expression.len().toString(), lenType)); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction( - NdpUdfEnum.SUBSTRING.getOperatorName()), FunctionKind.SCALAR, - new TypeSignature(returnType.toString()), new TypeSignature(strType.toString()), - new TypeSignature(posType.toString()), new TypeSignature(lenType.toString())); + QualifiedObjectName.valueOfDefaultFunction( + NdpUdfEnum.SUBSTRING.getOperatorName()), FunctionKind.SCALAR, + returnType.getTypeSignature(), strType.getTypeSignature(), + posType.getTypeSignature(), lenType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setPrestoRowExpression(resExpression); prestoExpressionInfo.setReturnType(returnType); } private void createNdpReplace(StringReplace expression, - PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + PrestoExpressionInfo prestoExpressionInfo, + Map fieldMap) { String signatureName = NdpUdfEnum.REPLACE.getSignatureName(); - Type srcType = NdpUtils.transOlkDataType(expression.srcExpr().dataType(), true); - Type searchType = NdpUtils.transOlkDataType( - expression.searchExpr().dataType(), true); - Type replaceType = NdpUtils.transOlkDataType( - expression.replaceExpr().dataType(), true); + Type srcType = NdpUtils.transOlkDataType(expression.srcExpr().dataType(), expression.srcExpr(), + true); + if (srcType instanceof CharType) { + srcType = createVarcharType(((CharType) srcType).getLength()); + } + Type searchType = NdpUtils.transOlkDataType(expression.searchExpr().dataType(), true); + Type replaceType = NdpUtils.transOlkDataType(expression.replaceExpr().dataType(), true); Type 
returnType = NdpUtils.transOlkDataType(expression.dataType(), true); + List rowArguments = new ArrayList<>(); checkAttributeReference(expression.srcExpr(), - prestoExpressionInfo, fieldMap, srcType, rowArguments); - rowArguments.add(NdpUtils.transArgumentData( - expression.searchExpr().toString(), searchType)); - rowArguments.add(NdpUtils.transArgumentData( - expression.replaceExpr().toString(), replaceType)); + prestoExpressionInfo, fieldMap, srcType, rowArguments); + rowArguments.add(NdpUtils.transConstantExpression( + expression.searchExpr().toString(), searchType)); + rowArguments.add(NdpUtils.transConstantExpression( + expression.replaceExpr().toString(), replaceType)); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction( - NdpUdfEnum.REPLACE.getOperatorName()), FunctionKind.SCALAR, - new TypeSignature(returnType.toString()), new TypeSignature(srcType.toString()), - new TypeSignature(searchType.toString()), new TypeSignature(replaceType.toString())); + QualifiedObjectName.valueOfDefaultFunction( + NdpUdfEnum.REPLACE.getOperatorName()), FunctionKind.SCALAR, + returnType.getTypeSignature(), srcType.getTypeSignature(), + searchType.getTypeSignature(), replaceType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setPrestoRowExpression(resExpression); } private void createNdpInstr(StringInstr expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { String signatureName = NdpUdfEnum.INSTR.getSignatureName(); - Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), true); + Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), expression.str(), + true); + if (strType instanceof CharType) { + strType = createVarcharType(((CharType) strType).getLength()); + } Type substrType = NdpUtils.transOlkDataType(expression.substr().dataType(), true); Type returnType = NdpUtils.transOlkDataType(expression.dataType(), true); + List rowArguments = new ArrayList<>(); checkAttributeReference(expression.str(), - prestoExpressionInfo, fieldMap, strType, rowArguments); - rowArguments.add(NdpUtils.transArgumentData( - expression.substr().toString(), substrType)); + prestoExpressionInfo, fieldMap, strType, rowArguments); + rowArguments.add(NdpUtils.transConstantExpression( + expression.substr().toString(), substrType)); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction( - NdpUdfEnum.INSTR.getOperatorName()), FunctionKind.SCALAR, - new TypeSignature(returnType.toString()), new TypeSignature(strType.toString()), - new TypeSignature(substrType.toString())); + QualifiedObjectName.valueOfDefaultFunction( + NdpUdfEnum.INSTR.getOperatorName()), FunctionKind.SCALAR, + returnType.getTypeSignature(), strType.getTypeSignature(), + substrType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setPrestoRowExpression(resExpression); } private void createNdpSplit(StringSplit expression, PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + Map fieldMap) { String signatureName = 
NdpUdfEnum.SPLIT.getSignatureName(); - Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), true); + Type strType = NdpUtils.transOlkDataType(expression.str().dataType(), expression.str(), + true); + if (strType instanceof CharType) { + strType = createVarcharType(((CharType) strType).getLength()); + } Type regexType = NdpUtils.transOlkDataType(expression.regex().dataType(), true); Type returnType = NdpUtils.transOlkDataType(expression.dataType(), true); + List rowArguments = new ArrayList<>(); checkAttributeReference(expression.str(), - prestoExpressionInfo, fieldMap, strType, rowArguments); - rowArguments.add(NdpUtils.transArgumentData( - expression.regex().toString(), regexType)); + prestoExpressionInfo, fieldMap, strType, rowArguments); + rowArguments.add(NdpUtils.transConstantExpression( + expression.regex().toString(), regexType)); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction( - NdpUdfEnum.SPLIT.getOperatorName()), FunctionKind.SCALAR, - new TypeSignature(returnType.toString()), new TypeSignature(strType.toString()), - new TypeSignature(regexType.toString())); + QualifiedObjectName.valueOfDefaultFunction( + NdpUdfEnum.SPLIT.getOperatorName()), FunctionKind.SCALAR, + returnType.getTypeSignature(), strType.getTypeSignature(), + regexType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setPrestoRowExpression(resExpression); } private void createNdpSubscript(GetArrayItem expression, - PrestoExpressionInfo prestoExpressionInfo, - Map fieldMap) { + PrestoExpressionInfo prestoExpressionInfo, + Map fieldMap) { String signatureName = NdpUdfEnum.SUBSCRIPT.getSignatureName(); - Type strType = NdpUtils.transOlkDataType(expression.child().dataType(), true); - Type ordinalType = NdpUtils.transOlkDataType( - expression.ordinal().dataType(), true); + Type strType = NdpUtils.transOlkDataType(expression.child().dataType(), expression.child(), + true); + if (strType instanceof CharType) { + strType = createVarcharType(((CharType) strType).getLength()); + } + Type ordinalType = NdpUtils.transOlkDataType(expression.ordinal().dataType(), true); Type returnType = NdpUtils.transOlkDataType(expression.dataType(), true); + List rowArguments = new ArrayList<>(); checkAttributeReference(expression.child(), - prestoExpressionInfo, fieldMap, strType, rowArguments); + prestoExpressionInfo, fieldMap, strType, rowArguments); // The presto`s array subscript is initially 1. 
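        // Worked example (illustrative, not part of the patch): Spark's GetArrayItem ordinal
        // is 0-based while Presto's $operator$subscript is 1-based, so a filter on arr[2]
        // must be pushed down as "$operator$subscript"(arr, 3); that is why 1 is added to
        // the literal ordinal below before it is turned into a constant expression.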
int argumentValue = Integer.parseInt( - ((Literal) expression.ordinal()).value().toString()) + 1; - rowArguments.add(NdpUtils.transArgumentData( - Integer.toString(argumentValue), ordinalType)); + ((Literal) expression.ordinal()).value().toString()) + 1; + rowArguments.add(NdpUtils.transConstantExpression( + Integer.toString(argumentValue), ordinalType)); Signature signature = new Signature( - QualifiedObjectName.valueOfDefaultFunction( - NdpUdfEnum.SUBSCRIPT.getOperatorName()), FunctionKind.SCALAR, - new TypeSignature(returnType.toString()), new TypeSignature(strType.toString()), - new TypeSignature(ordinalType.toString())); + QualifiedObjectName.valueOfDefaultFunction( + NdpUdfEnum.SUBSCRIPT.getOperatorName()), FunctionKind.SCALAR, + returnType.getTypeSignature(), strType.getTypeSignature(), + ordinalType.getTypeSignature()); RowExpression resExpression = new CallExpression( - signatureName, new BuiltInFunctionHandle(signature), - returnType, rowArguments); + signatureName, new BuiltInFunctionHandle(signature), + returnType, rowArguments); prestoExpressionInfo.setReturnType(returnType); prestoExpressionInfo.setPrestoRowExpression(resExpression); } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUtils.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUtils.java index 7333e4df1eb80d18357661fb59b89ecce30257a1..1e787d7c0567f8fe5fe0267b10ab23fa236d4331 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUtils.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/NdpUtils.java @@ -18,50 +18,73 @@ package org.apache.spark.sql; -import com.huawei.boostkit.omnidata.decode.type.*; +import static io.airlift.slice.Slices.utf8Slice; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.SmallintType.SMALLINT; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; +import static io.prestosql.spi.type.TinyintType.TINYINT; +import static io.prestosql.spi.type.VarcharType.*; +import static java.lang.Float.floatToIntBits; +import static java.lang.Float.parseFloat; + +import com.huawei.boostkit.omnidata.decode.type.BooleanDecodeType; +import com.huawei.boostkit.omnidata.decode.type.ByteDecodeType; +import com.huawei.boostkit.omnidata.decode.type.DateDecodeType; +import com.huawei.boostkit.omnidata.decode.type.DecimalDecodeType; +import com.huawei.boostkit.omnidata.decode.type.DecodeType; +import com.huawei.boostkit.omnidata.decode.type.DoubleDecodeType; +import com.huawei.boostkit.omnidata.decode.type.FloatDecodeType; +import com.huawei.boostkit.omnidata.decode.type.IntDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToByteDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToFloatDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToIntDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToShortDecodeType; +import com.huawei.boostkit.omnidata.decode.type.ShortDecodeType; +import com.huawei.boostkit.omnidata.decode.type.TimestampDecodeType; +import com.huawei.boostkit.omnidata.decode.type.VarcharDecodeType; 
+import com.huawei.boostkit.omnidata.model.Column; import io.airlift.slice.Slice; import io.prestosql.spi.relation.ConstantExpression; -import io.prestosql.spi.type.*; -import io.prestosql.spi.type.ArrayType; + +import io.prestosql.spi.type.CharType; import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.Decimals; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarcharType; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.Seq; -import org.apache.spark.sql.catalyst.expressions.*; import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import org.apache.spark.sql.execution.ndp.AggExeInfo; import org.apache.spark.sql.execution.ndp.LimitExeInfo; -import org.apache.spark.sql.types.*; -import org.apache.spark.sql.types.DateType; - -import scala.Option; -import scala.collection.JavaConverters; -import scala.collection.Seq; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.lang.reflect.Field; import java.math.BigDecimal; import java.math.BigInteger; -import java.time.format.DateTimeFormatter; -import java.time.format.DateTimeParseException; -import java.time.format.ResolverStyle; import java.util.HashMap; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; -import static io.airlift.slice.Slices.utf8Slice; -import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.spi.type.DateType.DATE; -import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.spi.type.IntegerType.INTEGER; -import static io.prestosql.spi.type.RealType.REAL; -import static io.prestosql.spi.type.SmallintType.SMALLINT; -import static io.prestosql.spi.type.TimestampType.TIMESTAMP; -import static io.prestosql.spi.type.TinyintType.TINYINT; -import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static java.lang.Float.floatToIntBits; -import static java.lang.Float.parseFloat; - /** * NdpUtils * @@ -69,6 +92,26 @@ import static java.lang.Float.parseFloat; */ public class NdpUtils { + /** + * Types supported by OmniOperator. 
+ */ + public static final Set supportTypes = new HashSet() { + { + add(StandardTypes.INTEGER); + add(StandardTypes.DATE); + add(StandardTypes.SMALLINT); + add(StandardTypes.BIGINT); + add(StandardTypes.VARCHAR); + add(StandardTypes.CHAR); + add(StandardTypes.DECIMAL); + add(StandardTypes.ROW); + add(StandardTypes.DOUBLE); + add(StandardTypes.VARBINARY); + add(StandardTypes.BOOLEAN); + } + }; + private static final Logger LOG = LoggerFactory.getLogger(NdpUtils.class); + public static int getColumnOffset(StructType dataSchema, Seq outPut) { List attributeList = JavaConverters.seqAsJavaList(outPut); String columnName = ""; @@ -92,6 +135,7 @@ public class NdpUtils { Seq aggExeInfo) { String columnName = ""; int columnTempId = 0; + boolean isFind = false; if (aggExeInfo != null && aggExeInfo.size() > 0) { List aggExecutionList = JavaConverters.seqAsJavaList(aggExeInfo); for (AggExeInfo aggExeInfoTemp : aggExecutionList) { @@ -106,15 +150,18 @@ public class NdpUtils { Matcher matcher = pattern.matcher(expression.toString()); if (matcher.find()) { columnTempId = Integer.parseInt(matcher.group(1)); + isFind = true; break; } } - break; + if (isFind) { + break; + } } List namedExpressions = JavaConverters.seqAsJavaList( aggExeInfoTemp.groupingExpressions()); for (NamedExpression namedExpression : namedExpressions) { - columnName = namedExpression.toString().split("#")[0]; + columnName = namedExpression.name(); columnTempId = NdpUtils.getColumnId(namedExpression.toString()); break; } @@ -145,13 +192,35 @@ public class NdpUtils { String adf = columnArrayId.substring(0, columnArrayId.length() - 1); columnTempId = Integer.parseInt(adf); } else { - columnTempId = Integer.parseInt(columnArrayId); + if (columnArrayId.contains(")")) { + columnTempId = Integer.parseInt(columnArrayId.split("\\)")[0].replaceAll("[^(\\d+)]", "")); + } else { + columnTempId = Integer.parseInt(columnArrayId); + } } return columnTempId; } + /** + * transform spark data type to omnidata + * + * @param dataType spark data type + * @param isSparkUdfOperator is spark udf + * @return result type + */ public static Type transOlkDataType(DataType dataType, boolean isSparkUdfOperator) { - String strType = dataType.toString().toLowerCase(Locale.ENGLISH); + return transOlkDataType(dataType, null, isSparkUdfOperator); + } + + public static Type transOlkDataType(DataType dataType, Object attribute, boolean isSparkUdfOperator) { + String strType; + Metadata metadata = Metadata.empty(); + if (attribute instanceof Attribute) { + metadata = ((Attribute) attribute).metadata(); + strType = ((Attribute) attribute).dataType().toString().toLowerCase(Locale.ENGLISH); + } else { + strType = dataType.toString().toLowerCase(Locale.ENGLISH); + } if (isSparkUdfOperator && "integertype".equalsIgnoreCase(strType)) { strType = "longtype"; } @@ -179,23 +248,24 @@ public class NdpUtils { case "booleantype": return BOOLEAN; case "stringtype": - return VARCHAR; + if (CharVarcharUtils.getRawTypeString(metadata).isDefined()) { + String metadataStr = CharVarcharUtils.getRawTypeString(metadata).get(); + Pattern pattern = Pattern.compile("(?<=\\()\\d+(?=\\))"); + Matcher matcher = pattern.matcher(metadataStr); + String len = String.valueOf(UNBOUNDED_LENGTH); + while (matcher.find()) { + len = matcher.group(); + } + if (metadataStr.startsWith("char")) { + return CharType.createCharType(Integer.parseInt(len)); + } else if (metadataStr.startsWith("varchar")) { + return createVarcharType(Integer.parseInt(len)); + } + } else { + return VARCHAR; + } case "datetype": return 
DATE; - case "arraytype(stringtype,true)": - case "arraytype(stringtype,false)": - return new ArrayType<>(VARCHAR); - case "arraytype(integertype,true)": - case "arraytype(integertype,false)": - case "arraytype(longtype,true)": - case "arraytype(longtype,false)": - return new ArrayType<>(BIGINT); - case "arraytype(floattype,true)": - case "arraytype(floattype,false)": - return new ArrayType<>(REAL); - case "arraytype(doubletype,true)": - case "arraytype(doubletype,false)": - return new ArrayType<>(DOUBLE); default: throw new UnsupportedOperationException("unsupported this type:" + strType); } @@ -232,23 +302,17 @@ public class NdpUtils { if (BOOLEAN.equals(prestoType)) { return new BooleanDecodeType(); } - if (VARCHAR.equals(prestoType)) { + if (VARCHAR.equals(prestoType) || prestoType instanceof CharType) { return new VarcharDecodeType(); } if (DATE.equals(prestoType)) { return new DateDecodeType(); } - throw new RuntimeException("unsupported this prestoType:" + prestoType); + throw new UnsupportedOperationException("unsupported this prestoType:" + prestoType); } - public static DecodeType transDataIoDataType(DataType dataType) { + public static DecodeType transDecodeType(DataType dataType) { String strType = dataType.toString().toLowerCase(Locale.ENGLISH); - if (strType.contains("decimal")) { - String[] decimalInfo = strType.split("\\(")[1].split("\\)")[0].split(","); - int precision = Integer.parseInt(decimalInfo[0]); - int scale = Integer.parseInt(decimalInfo[1]); - return new DecimalDecodeType(precision, scale); - } switch (strType) { case "timestamptype": return new TimestampDecodeType(); @@ -271,61 +335,75 @@ public class NdpUtils { case "datetype": return new DateDecodeType(); default: - throw new RuntimeException("unsupported this type:" + strType); + if (strType.contains("decimal")) { + String[] decimalInfo = strType.split("\\(")[1].split("\\)")[0].split(","); + int precision = Integer.parseInt(decimalInfo[0]); + int scale = Integer.parseInt(decimalInfo[1]); + return new DecimalDecodeType(precision, scale); + } else { + throw new UnsupportedOperationException("unsupported this type:" + strType); + } } } - public static TypeSignature createTypeSignature(DataType type, boolean isPrestoUdfOperator) { - Type realType = NdpUtils.transOlkDataType(type, isPrestoUdfOperator); - return createTypeSignature(realType); - } - - public static TypeSignature createTypeSignature(Type type) { - String typeName = type.toString(); - if (type instanceof DecimalType) { - String[] decimalInfo = typeName.split("\\(")[1].split("\\)")[0].split(","); - long precision = Long.parseLong(decimalInfo[0]); - long scale = Long.parseLong(decimalInfo[1]); - return new TypeSignature("decimal", TypeSignatureParameter.of(precision), TypeSignatureParameter.of(scale)); + /** + * Convert decimal data to a constant expression + * + * @param argumentValue value + * @param decimalType decimalType + * @return ConstantExpression + */ + public static ConstantExpression transDecimalConstant(String argumentValue, + DecimalType decimalType) { + BigInteger bigInteger = + Decimals.rescale(new BigDecimal(argumentValue), decimalType).unscaledValue(); + if (decimalType.isShort()) { + return new ConstantExpression(bigInteger.longValue(), decimalType); } else { - return new TypeSignature(typeName); + Slice argumentValueSlice = Decimals.encodeUnscaledValue(bigInteger); + long[] base = new long[]{argumentValueSlice.getLong(0), argumentValueSlice.getLong(8)}; + try { + Field filed = Slice.class.getDeclaredField("base"); + 
filed.setAccessible(true); + filed.set(argumentValueSlice, base); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new UnsupportedOperationException("create long decimal data failed"); + } + return new ConstantExpression(argumentValueSlice, decimalType); } } - public static ConstantExpression transArgumentData(String argumentValue, Type argumentType) { - String strType = argumentType.toString().toLowerCase(Locale.ENGLISH); - if (strType.contains("decimal")) { - String[] parameter = strType.split("\\(")[1].split("\\)")[0].split(","); - int precision = Integer.parseInt(parameter[0]); - int scale = Integer.parseInt(parameter[1]); - BigInteger bigInteger = Decimals.rescale(new BigDecimal(argumentValue), (DecimalType) argumentType).unscaledValue(); - if ("ShortDecimalType".equals(argumentType.getClass().getSimpleName())) { //short decimal type - return new ConstantExpression(bigInteger.longValue(), DecimalType.createDecimalType(precision, scale)); - } else if ("LongDecimalType".equals(argumentType.getClass().getSimpleName())) { //long decimal type - Slice argumentValueSlice = Decimals.encodeUnscaledValue(bigInteger); - long[] base = new long[2]; - base[0] = argumentValueSlice.getLong(0); - base[1] = argumentValueSlice.getLong(8); - try { - Field filed = Slice.class.getDeclaredField("base"); - filed.setAccessible(true); - filed.set(argumentValueSlice, base); - } catch (Exception e) { - e.printStackTrace(); - } - return new ConstantExpression(argumentValueSlice, DecimalType.createDecimalType(precision, scale)); - } else { - throw new UnsupportedOperationException("unsupported data type " + argumentType.getClass().getSimpleName()); - } + /** + * Convert data to a constant expression + * process 'null' data + * + * @param argumentValue value + * @param argumentType argumentType + * @return ConstantExpression + */ + public static ConstantExpression transConstantExpression(String argumentValue, Type argumentType) { + if (argumentType instanceof CharType) { + Slice charValue = utf8Slice(stripEnd(argumentValue, " ")); + return new ConstantExpression(charValue, argumentType); + } + if (argumentType instanceof VarcharType) { + Slice charValue = utf8Slice(argumentValue); + return new ConstantExpression(charValue, argumentType); + } + if (argumentValue.equals("null")) { + return new ConstantExpression(null, argumentType); } + if (argumentType instanceof DecimalType) { + return transDecimalConstant(argumentValue, (DecimalType) argumentType); + } + String strType = argumentType.toString().toLowerCase(Locale.ENGLISH); switch (strType) { case "bigint": case "integer": case "date": case "tinyint": case "smallint": - long longValue = Long.parseLong(argumentValue); - return new ConstantExpression(longValue, argumentType); + return new ConstantExpression(Long.parseLong(argumentValue), argumentType); case "real": return new ConstantExpression( (long) floatToIntBits(parseFloat(argumentValue)), argumentType); @@ -333,9 +411,6 @@ public class NdpUtils { return new ConstantExpression(Double.valueOf(argumentValue), argumentType); case "boolean": return new ConstantExpression(Boolean.valueOf(argumentValue), argumentType); - case "varchar": - Slice charValue = utf8Slice(argumentValue); - return new ConstantExpression(charValue, argumentType); case "timestamp": int rawOffset = TimeZone.getDefault().getRawOffset(); long timestampValue; @@ -347,7 +422,8 @@ public class NdpUtils { } } else { int millisecondsDiffMicroseconds = 3; - timestampValue = Long.parseLong(argumentValue.substring(0, 
argumentValue.length() - millisecondsDiffMicroseconds)) + rawOffset; + timestampValue = Long.parseLong(argumentValue.substring(0, + argumentValue.length() - millisecondsDiffMicroseconds)) + rawOffset; } return new ConstantExpression(timestampValue, argumentType); default: @@ -369,31 +445,6 @@ public class NdpUtils { return resAttribute; } - public static Object transData(String sparkType, String columnValue) { - String strType = sparkType.toLowerCase(Locale.ENGLISH); - switch (strType) { - case "integertype": - return Integer.valueOf(columnValue); - case "bytetype": - return Byte.valueOf(columnValue); - case "shorttype": - return Short.valueOf(columnValue); - case "longtype": - return Long.valueOf(columnValue); - case "floattype": - return (long) floatToIntBits(parseFloat(columnValue)); - case "doubletype": - return Double.valueOf(columnValue); - case "booleantype": - return Boolean.valueOf(columnValue); - case "stringtype": - case "datetype": - return columnValue; - default: - return ""; - } - } - public static OptionalLong convertLimitExeInfo(Option limitExeInfo) { return limitExeInfo.isEmpty() ? OptionalLong.empty() : OptionalLong.of(limitExeInfo.get().limit()); @@ -420,23 +471,41 @@ public class NdpUtils { return (int) (Math.random() * hostSize); } - public static boolean isValidDateFormat(String dateString) { - boolean isValid = true; - String pattern = "yyyy-MM-dd"; - DateTimeFormatter formatter = DateTimeFormatter.ofPattern(pattern).withResolverStyle(ResolverStyle.STRICT); - try { - formatter.parse(dateString); - } catch (DateTimeParseException e) { - isValid = false; - } - return isValid; + /** + * Check if the input pages contains datatypes unsuppoted by OmniColumnVector. + * + * @param columns Input columns + * @return false if contains unsupported type + */ + public static boolean checkOmniOpColumns(List columns) { + for (Column column : columns) { + String base = column.getType().getTypeSignature().getBase(); + if (!supportTypes.contains(base)) { + LOG.info("Unsupported operator data type {}, rollback", base); + return false; + } + } + return true; } - public static boolean isInDateExpression(Expression expression, String Operator) { - boolean isInDate = false; - if (expression instanceof Cast && Operator.equals("in")) { - isInDate = ((Cast) expression).child().dataType() instanceof DateType; + public static String stripEnd(String str, String stripChars) { + int end; + if (str != null && (end = str.length()) != 0) { + if (stripChars == null) { + while (end != 0 && Character.isWhitespace(str.charAt(end - 1))) { + --end; + } + } else { + if (stripChars.isEmpty()) { + return str; + } + while (end != 0 && stripChars.indexOf(str.charAt(end - 1)) != -1) { + --end; + } + } + return str.substring(0, end); + } else { + return str; } - return isInDate; } -} +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageCandidate.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageCandidate.java index 8ca14685e342845cd93de93cdf50190fcbf7e7a0..db8dfeef8024284c2a5d9037dfd30612de80e0b0 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageCandidate.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageCandidate.java @@ -31,24 +31,27 @@ public class PageCandidate { public int columnOffset; - public String sdiHosts; + public String pushDownHosts; - private String fileFormat; + private final String fileFormat; 
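    // Illustrative construction with the widened constructor (all values are placeholders,
    // not part of this patch):
    //     PageCandidate candidate = new PageCandidate("/warehouse/t/part-00000", 0L,
    //             134217728L, 2, "host1,host2", "orc", 4, 300, true);
    // The trailing boolean is the new isOperatorCombineEnabled flag, and pushDownHosts
    // replaces the former sdiHosts field name.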
public int maxFailedTimes; - private int taskTimeout; + private final int taskTimeout; - public PageCandidate(String filePath, Long startPos, Long splitLen, int columnOffset, - String sdiHosts, String fileFormat, int maxFailedTimes, int taskTimeout) { + private final boolean isOperatorCombineEnabled; + + public PageCandidate(String filePath, Long startPos, Long splitLen, int columnOffset, String pushDownHosts, + String fileFormat, int maxFailedTimes, int taskTimeout, boolean isOperatorCombineEnabled) { this.filePath = filePath; this.startPos = startPos; this.splitLen = splitLen; this.columnOffset = columnOffset; - this.sdiHosts = sdiHosts; + this.pushDownHosts = pushDownHosts; this.fileFormat = fileFormat; this.maxFailedTimes = maxFailedTimes; this.taskTimeout = taskTimeout; + this.isOperatorCombineEnabled = isOperatorCombineEnabled; } public Long getStartPos() { @@ -67,8 +70,8 @@ public class PageCandidate { return columnOffset; } - public String getSdiHosts() { - return sdiHosts; + public String getpushDownHosts() { + return pushDownHosts; } public String getFileFormat() { @@ -82,4 +85,8 @@ public class PageCandidate { public int getTaskTimeout() { return taskTimeout; } + + public boolean isOperatorCombineEnabled() { + return isOperatorCombineEnabled; + } } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageToColumnar.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageToColumnar.java index 42e7bc1bdf9bfc7533ef5d253bbc2d3a1fb74eb5..7922e69c5e4564e25d76c5cf5b623ac1a6eaf02d 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageToColumnar.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PageToColumnar.java @@ -19,64 +19,120 @@ package org.apache.spark.sql; import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.util.CharVarcharUtils; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StringType; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.JavaConverters; import scala.collection.Seq; import java.io.Serializable; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static io.prestosql.spi.type.VarcharType.UNBOUNDED_LENGTH; +import static org.apache.commons.lang.StringUtils.rightPad; /** * PageToColumnar */ public class PageToColumnar implements Serializable { - StructType structType = null; - Seq outPut = null; + private static final Logger LOG = LoggerFactory.getLogger(PageToColumnar.class); + + private static final String ORC_HIVE = "hive"; + + private static final String METADATA_CHAR = "char"; + + StructType structType; + Seq outPut; + public PageToColumnar(StructType structType, Seq outPut) { this.structType = structType; this.outPut = outPut; } public List transPageToColumnar(Iterator writableColumnVectors, - boolean 
isVectorizedReader) { - scala.collection.Iterator structFieldIterator = structType.iterator(); - List columnType = new ArrayList<>(); - - while (structFieldIterator.hasNext()) { - columnType.add(structFieldIterator.next().dataType()); + boolean isVectorizedReader, boolean isOperatorCombineEnabled, Seq sparkOutput, String orcImpl) { + if (isOperatorCombineEnabled) { + LOG.debug("OmniRuntime PushDown column info: OmniColumnVector transform to Columnar"); } List internalRowList = new ArrayList<>(); + List outputColumnList = JavaConverters.seqAsJavaList(sparkOutput); while (writableColumnVectors.hasNext()) { WritableColumnVector[] columnVector = writableColumnVectors.next(); if (columnVector == null) { continue; } int positionCount = columnVector[0].getElementsAppended(); - if (positionCount > 0) { - if (isVectorizedReader) { - ColumnarBatch columnarBatch = new ColumnarBatch(columnVector); - columnarBatch.setNumRows(positionCount); - internalRowList.add(columnarBatch); - } else { - for (int j = 0; j < positionCount; j++) { - MutableColumnarRow mutableColumnarRow = - new MutableColumnarRow(columnVector); - mutableColumnarRow.rowId = j; - internalRowList.add(mutableColumnarRow); + if (positionCount <= 0) { + continue; + } + if (isVectorizedReader) { + ColumnarBatch columnarBatch = new ColumnarBatch(columnVector); + columnarBatch.setNumRows(positionCount); + internalRowList.add(columnarBatch); + } else { + for (int j = 0; j < positionCount; j++) { + // when outputColumnList is empty, the output does not need to be processed. + if (!outputColumnList.isEmpty()) { + procVectorForOrcHive(columnVector, orcImpl, outputColumnList, j); } + MutableColumnarRow mutableColumnarRow = + new MutableColumnarRow(columnVector); + mutableColumnarRow.rowId = j; + internalRowList.add(mutableColumnarRow); } } } return internalRowList; } -} - - + public void procVectorForOrcHive(WritableColumnVector[] columnVectors, String orcImpl, List outputColumnList, int rowId) { + if (orcImpl.equals(ORC_HIVE)) { + for (int i = 0; i < columnVectors.length; i++) { + if (columnVectors[i].dataType() instanceof StringType) { + Attribute attribute = outputColumnList.get(i); + Metadata metadata = attribute.metadata(); + putPaddingChar(columnVectors[i], metadata, rowId); + } + } + } + } + private void putPaddingChar(WritableColumnVector columnVector, Metadata metadata, int rowId) { + if (CharVarcharUtils.getRawTypeString(metadata).isDefined()) { + String metadataStr = CharVarcharUtils.getRawTypeString(metadata).get(); + Pattern pattern = Pattern.compile("(?<=\\()\\d+(?=\\))"); + Matcher matcher = pattern.matcher(metadataStr); + String len = String.valueOf(UNBOUNDED_LENGTH); + while(matcher.find()){ + len = matcher.group(); + } + if (metadataStr.startsWith(METADATA_CHAR)) { + UTF8String utf8String = columnVector.getUTF8String(rowId); + if (utf8String == null) { + return; + } + String vecStr = utf8String.toString(); + String vecStrPad = rightPad(vecStr, Integer.parseInt(len), ' '); + byte[] bytes = vecStrPad.getBytes(StandardCharsets.UTF_8); + if (columnVector instanceof OnHeapColumnVector) { + columnVector.putByteArray(rowId, bytes, 0, bytes.length); + } else { + columnVector.putBytes(rowId, bytes.length, bytes, 0); + } + } + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PushDownManager.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PushDownManager.java index a1278adabfbe329973535b61ea959119fa282131..75d7b1cc0b1fb06f45a6f5f2c4b765d798a90b50 100644 
--- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PushDownManager.java +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/PushDownManager.java @@ -47,9 +47,8 @@ public class PushDownManager { private static final int ZOOKEEPER_RETRY_INTERVAL_MS = 1000; - public scala.collection.Map getZookeeperData( - int timeOut, String parentPath, String zkAddress) throws Exception { - Map fpuMap = new HashMap<>(); + public scala.collection.Map getZookeeperData( + int timeOut, String parentPath, String zkAddress) throws Exception { CuratorFramework zkClient = CuratorFrameworkFactory.builder() .connectString(zkAddress) .sessionTimeoutMs(timeOut) @@ -67,12 +66,11 @@ public class PushDownManager { if (!path.contains("-lock-")) { byte[] data = zkClient.getData().forPath(parentPath + "/" + path); PushDownData statusInfo = mapper.readValue(data, PushDownData.class); - fpuMap.put(path, statusInfo.getDatanodeHost()); pushDownInfoMap.put(path, statusInfo); } } if (checkAllPushDown(pushDownInfoMap)) { - return javaMapToScala(fpuMap); + return javaMapToScala(pushDownInfoMap); } else { return javaMapToScala(new HashMap<>()); } @@ -110,11 +108,11 @@ public class PushDownManager { return true; } - private static scala.collection.Map javaMapToScala(Map kafkaParams) { + private static scala.collection.Map javaMapToScala(Map kafkaParams) { scala.collection.Map scalaMap = JavaConverters.mapAsScalaMap(kafkaParams); Object objTest = Map$.MODULE$.newBuilder().$plus$plus$eq(scalaMap.toSeq()); Object resultTest = ((scala.collection.mutable.Builder) objTest).result(); - scala.collection.Map retMap = (scala.collection.Map) resultTest; + scala.collection.Map retMap = (scala.collection.Map) resultTest; return retMap; } -} +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRadixRowSorter.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRadixRowSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..cfbd86de018fcb2866536aae05418acc4fe7cc41 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRadixRowSorter.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.List; +import java.util.function.Supplier; + +import org.apache.spark.util.collection.unsafe.sort.*; +import scala.collection.Iterator; +import scala.math.Ordering; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.Platform; + +public class UnsafeExternalRadixRowSorter { + + /** + * If positive, forces records to be spilled to disk at the given frequency (measured in numbers + * of records). This is only intended to be used in tests. + */ + private int testSpillFrequency = 0; + + private long numRowsInserted = 0; + + private final StructType schema; + private final List prefixComputers; + private final UnsafeExternalRadixSorter sorter; + + // This flag makes sure the cleanupResource() has been called. After the cleanup work, + // iterator.next should always return false. Downstream operator triggers the resource + // cleanup while they found there's no need to keep the iterator any more. + // See more details in SPARK-21492. + private boolean isReleased = false; + + public abstract static class PrefixComputer { + + public static class Prefix { + /** + * Key prefix value, or the null prefix value if isNull = true. + **/ + public long value; + + /** + * Whether the key is null. + */ + public boolean isNull; + } + + /** + * Computes prefix for the given row. For efficiency, the returned object may be reused in + * further calls to a given PrefixComputer. + */ + public abstract Prefix computePrefix(InternalRow row); + } + + public static UnsafeExternalRadixRowSorter create( + StructType schema, + Ordering ordering, + List prefixComparators, + List prefixComputers, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + // 空的列比较器 + Supplier recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRadixRowSorter(schema, recordComparatorSupplier, prefixComparators, + prefixComputers, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRadixRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + List prefixComparators, + List prefixComputers, + long pageSizeBytes, + boolean canUseRadixSort) { + this.schema = schema; + this.prefixComputers = prefixComputers; + final SparkEnv sparkEnv = SparkEnv.get(); + final TaskContext taskContext = TaskContext.get(); + sorter = UnsafeExternalRadixSorter.create( + taskContext.taskMemoryManager(), + sparkEnv.blockManager(), + sparkEnv.serializerManager(), + taskContext, + recordComparatorSupplier, + prefixComparators, + (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), + pageSizeBytes, + (int) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + canUseRadixSort + ); + } + + /** + * Forces spills to occur every `frequency` records. Only for use in tests. 
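Callers of create() above supply one PrefixComputer per sort column. As a point of reference, a computer for a BIGINT column at a hypothetical ordinal 0 could look like the sketch below; the ordinal and the null handling are assumptions for illustration, not something this patch prescribes (InternalRow is the class already imported in this file):

    UnsafeExternalRadixRowSorter.PrefixComputer bigintPrefix =
        new UnsafeExternalRadixRowSorter.PrefixComputer() {
            private final Prefix prefix = new Prefix();    // reused across rows, as documented above
            @Override
            public Prefix computePrefix(InternalRow row) {
                prefix.isNull = row.isNullAt(0);           // hypothetical sort column ordinal 0
                prefix.value = prefix.isNull ? 0L : row.getLong(0);
                return prefix;
            }
        };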
+ */ + @VisibleForTesting + public void setTestSpillFrequency(int frequency) { + assert frequency > 0 : "Frequency must be positive"; + testSpillFrequency = frequency; + } + + public void insertRow(UnsafeRow row) throws IOException { + final PrefixComputer.Prefix prefix1 = prefixComputers.get(0).computePrefix(row); + final PrefixComputer.Prefix prefix2 = prefixComputers.get(1).computePrefix(row); + sorter.insertRecord( + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), + prefix1.value, + prefix1.isNull, + prefix2.value, + prefix2.isNull + ); + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + sorter.spill(); + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + + /** + * @return the total amount of time spent sorting data (in-memory only). + */ + public long getSortTimeNanos() { + return sorter.getSortTimeNanos(); + } + + public void cleanupResources() { + isReleased = true; + sorter.cleanupResources(); + } + + public Iterator sort() throws IOException { + try { + final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + if (!sortedIterator.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } + return new RowIterator() { + + private final int numFields = schema.length(); + private UnsafeRow row = new UnsafeRow(numFields); + + @Override + public boolean advanceNext() { + try { + if (!isReleased && sortedIterator.hasNext()) { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + sortedIterator.getRecordLength()); + // Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug + // when returning the last row from an iterator. For example, in + // [[GroupedIterator]], we still use the last row after traversing the iterator + // in `fetchNextGroupIterator` + if (!sortedIterator.hasNext()) { + row = row.copy(); // so that we don't have dangling pointers to freed page + cleanupResources(); + } + return true; + } else { + row = null; // so that we don't keep references to the base object + return false; + } + } catch (IOException e) { + cleanupResources(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + Platform.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + + @Override + public UnsafeRow getRow() { + return row; + } + + }.toScala(); + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); + } + + private static final class RowComparator extends RecordComparator { + private final Ordering ordering; + private final UnsafeRow row1; + private final UnsafeRow row2; + + RowComparator(Ordering ordering, int numFields) { + this.row1 = new UnsafeRow(numFields); + this.row2 = new UnsafeRow(numFields); + this.ordering = ordering; + } + + @Override + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { + // Note that since ordering doesn't need the total length of the record, we just pass 0 + // into the row. 
+ row1.pointTo(baseObj1, baseOff1, 0); + row2.pointTo(baseObj2, baseOff2, 0); + return ordering.compare(row1, row2); + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixTwoColumnSort.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixTwoColumnSort.java new file mode 100644 index 0000000000000000000000000000000000000000..aaf30a44eacb32df6acc06428db6b4e4df306edc --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixTwoColumnSort.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; + +public class RadixTwoColumnSort { + + /** + * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes + * you have extra space at the end of the array at least equal to the number of records. The + * sort is destructive and may relocate the data positioned within the array. + * + * @param array array of long elements followed by at least that many empty slots. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. + * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. Must be greater than startByteIndex. + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * @return The starting index of the sorted data within the given array. We return this instead + * of always copying the data back to position zero for efficiency. 
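Before reading the in-place implementation that follows, it may help to see the same idea without the LongArray and offset machinery: byte-wise LSD radix sort is a series of stable counting sorts, one per byte, ping-ponging between two buffers exactly as inIndex/outIndex do below. A plain-array sketch, unsigned ascending order only, with none of the descending or signed handling:

    static long[] lsdRadixSort(long[] input) {
        long[] src = input.clone();
        long[] dst = new long[input.length];
        for (int byteIdx = 0; byteIdx < 8; byteIdx++) {
            int[] counts = new int[256];
            for (long v : src) {
                counts[(int) ((v >>> (byteIdx * 8)) & 0xff)]++;      // histogram for this byte
            }
            int pos = 0;
            for (int b = 0; b < 256; b++) {                          // counts -> starting offsets
                int c = counts[b];
                counts[b] = pos;
                pos += c;
            }
            for (long v : src) {                                     // stable scatter pass
                dst[counts[(int) ((v >>> (byteIdx * 8)) & 0xff)]++] = v;
            }
            long[] tmp = src;                                        // swap buffer roles, like inIndex/outIndex
            src = dst;
            dst = tmp;
        }
        return src;
    }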
+ */ + public static int sort( + LongArray array, long numRecords, int startByteIndex, int endByteIndex, + boolean desc, boolean signed) { + assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; + assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; + assert endByteIndex > startByteIndex; + assert numRecords * 2 <= array.size(); + long inIndex = 0; + long outIndex = numRecords; + if (numRecords > 0) { + long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (counts[i] != null) { + sortAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex); + long tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + } + return Ints.checkedCast(inIndex); + } + + /** + * Performs a partial sort by copying data into destination offsets for each byte value at the + * specified byte offset. + * + * @param array array to partially sort. + * @param numRecords number of data records in the array. + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param byteIdx the byte in a long to sort at, counting from the least significant byte. + * @param inIndex the starting index in the array where input data is located. + * @param outIndex the starting index where sorted output data should be written. + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort (only applies to last byte). + */ + private static void sortAtByte( + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, + boolean desc, boolean signed) { + assert counts.length == 256; + long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); + Object baseObject = array.getBaseObject(); + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; + for (long offset = baseOffset; offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + int bucket = (int) ((value >>> (byteIdx * 8)) & 0xff); + Platform.putLong(baseObject, offsets[bucket], value); + offsets[bucket] += 8; + } + } + + /** + * Computes a value histogram for each byte in the given array. + * + * @param array array to count records in. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte to compute counts for (the prior are skipped). + * @param endByteIndex the last byte to compute counts for. + * @return an array of eight 256-byte count arrays, one for each byte starting from the least + * significant byte. If the byte does not need sorting the array will be null. + */ + private static long[][] getCounts( + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { + long[][] counts = new long[8][]; + // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. + // If all the byte values at a particular index are the same we don't need to count it. 
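A small worked example of that pre-pass, with made-up values: OR-ing every record into bitwiseMax and AND-ing every record into bitwiseMin leaves a zero byte in bitwiseMin ^ bitwiseMax exactly where all records agree, and that byte's counting pass can be skipped.

    long bitwiseMax = 0;
    long bitwiseMin = -1L;
    for (long v : new long[] {0x1122L, 0x1133L, 0x1144L}) {
        bitwiseMax |= v;
        bitwiseMin &= v;
    }
    long bitsChanged = bitwiseMin ^ bitwiseMax;
    // Byte 1 is 0x11 in every record, so those eight bits cancel out ...
    assert ((bitsChanged >>> 8) & 0xff) == 0;
    // ... while byte 0 differs between records and still needs its counting pass.
    assert (bitsChanged & 0xff) != 0;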
+ long bitwiseMax = 0; + long bitwiseMin = -1L; + long maxOffset = array.getBaseOffset() + numRecords * 8L; + Object baseObject = array.getBaseObject(); + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + bitwiseMax |= value; + bitwiseMin &= value; + } + long bitsChanged = bitwiseMin ^ bitwiseMax;
+ // Compute counts for each byte index. + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + // TODO(ekl) consider computing all the counts in one pass. + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + counts[i][(int) ((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++; + } + } + } + return counts; + }
+ + /** + * Transforms counts into the proper unsafe output offsets for the sort type. + * + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param numRecords number of data records in the original data array. + * @param outputOffset output offset in bytes from the base array object. + * @param bytesPerRecord size of each record (8 for plain sort, 24 for the two-prefix key-prefix sort). + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * @return the input counts array. + */
+ private static long[] transformCountsToOffsets( + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, + boolean desc, boolean signed) { + assert counts.length == 256;
+ // Signed case; the bigint prefixes used here are all treated as unsigned.
+ int start = signed ? 128 : 0; // output the negative records first (values 129-255).
+ if (desc) { + long pos = numRecords; + for (int i = start; i < start + 256; i++) { + pos -= counts[i & 0xff]; + counts[i & 0xff] = outputOffset + pos * bytesPerRecord; + } + } else { + long pos = 0;
+ // Walk over each of the 256 buckets.
+ for (int i = start; i < start + 256; i++) {
+ // Number of records that fall into this bucket.
+ long tmp = counts[i & 0xff];
+ // Replace the count with the position in the LongArray where this bucket's records will be written.
+ // outputOffset is the byte position right after the input records, i.e. where the output region starts in the LongArray.
+ // For example, if counts[0] holds 3 records and counts[1] holds 1, with bytesPerRecord = 16,
+ // then after the transformation counts[0] = 0, counts[1] = 16 * 3, counts[2] = 16 * 4 (relative to outputOffset).
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ pos += tmp;
+ } + } + return counts; + }
+ + /** + * Specialization of sort() for key-prefix arrays. In this two-column variant each record consists + * of three longs: a record pointer followed by two key prefixes, and only the prefixes are sorted on. + * + * @param startIndex starting index in the array to sort from. This parameter is not supported + * in the plain sort() implementation.
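One thing worth keeping in mind when reading sortKeyPrefixArray and its helpers: every logical record in the LongArray spans three longs, which is where the +8/+16 offsets and the 24-byte stride come from. Hypothetical accessors that make the layout explicit (not part of the patch; LongArray is the class already imported here):

    // Record i, counting longs from the start of the record region:
    //   [ pointer ][ prefix1 (primary key) ][ prefix2 (secondary key) ]
    static long pointerOf(LongArray array, long start, long i) { return array.get(start + 3 * i); }
    static long prefix1Of(LongArray array, long start, long i) { return array.get(start + 3 * i + 1); }
    static long prefix2Of(LongArray array, long start, long i) { return array.get(start + 3 * i + 2); }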
+ */ + public static int sortKeyPrefixArray( + LongArray array, + long startIndex, + long numRecords, + int startByteIndex, + int endByteIndex, + boolean desc, + boolean signed, + int prefixShiftOffset) {
+ // Radix sort works on the long prefix values byte by byte; a long has 8 bytes, compared from the low byte startByteIndex upwards.
+ assert numRecords * 6 <= array.size();
+ // Index of the first long of the records in the LongArray.
+ long inIndex = startIndex;
+ // Index of the first long past the records, i.e. the start of the scratch (output) region.
+ long outIndex = startIndex + numRecords * 3L;
+ if (numRecords > 0) {
+ // Change 3: below is the sort-by-a-single-prefix pass, performed twice, once for each prefix in turn.
+ // counts is a long[16][256] table:
+ // for each byte position (at most 8 per prefix), the number of records that fall into each of the 256 buckets.
+ long[][] counts = getKeyPrefixArrayCounts( + array, startIndex, numRecords, startByteIndex, endByteIndex, prefixShiftOffset);
+ // Walk over each byte position and convert counts into the LongArray positions where each bucket's records will be written (the unused half of the LongArray acts as the bucket storage).
+ for (int i = startByteIndex; i <= 7; i++) { + if (counts[i] != null) {
+ // Convert counts into the LongArray positions where this bucket's records will be written (the unused half of the LongArray acts as the bucket storage).
+ sortKeyPrefixArrayAtByte2( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex);
+ long tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + }
+ for (int i = 8; i <= endByteIndex; i++) { + if (counts[i] != null) {
+ // Convert counts into the LongArray positions where this bucket's records will be written (the unused half of the LongArray acts as the bucket storage).
+ sortKeyPrefixArrayAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex);
+ long tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + }
+ return Ints.checkedCast(inIndex); + }
+ + /** + * Specialization of getCounts() for key-prefix arrays. We could probably combine this with + * getCounts with some added parameters but that seems to hurt in benchmarks. + */
+ private static long[][] getKeyPrefixArrayCounts( + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex, int prefixShiftOffset) {
+ long[][] counts = new long[16][];
+ long bitwiseMax1 = 0; + long bitwiseMin1 = -1L;
+ // Memory position of the first record.
+ long baseOffset = array.getBaseOffset() + startIndex * 8L;
+ // End memory position of the last record.
+ long limit = baseOffset + numRecords * 24L;
+ Object baseObject = array.getBaseObject();
+ // Scan the first prefix of every record to find which bits differ.
+ for (long offset = baseOffset; offset < limit; offset += 24) { + long value = Platform.getLong(baseObject, offset + 8); + bitwiseMax1 |= value; + bitwiseMin1 &= value; + }
+ long bitsChanged1 = bitwiseMin1 ^ bitwiseMax1;
+ // For byte positions 0 through 7,
+ // scan all records and count how many fall into each bucket at that byte position.
+ for (int i = 0; i <= 7; i++) { + if (((bitsChanged1 >>> (i * 8)) & 0xff) != 0) { + counts[i + 8] = new long[256]; + for (long offset = baseOffset; offset < limit; offset += 24) { + counts[i + 8][(int) ((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++; + } + } + }
+ + + long bitwiseMax2 = 0; + long bitwiseMin2 = -1L;
+ // Scan the second prefix of every record to find which bits differ.
+ for (long offset = baseOffset; offset < limit; offset += 24) { + long value = Platform.getLong(baseObject, offset + 16); + bitwiseMax2 |= value; + bitwiseMin2 &= value; + }
+ long bitsChanged2 = bitwiseMin2 ^ bitwiseMax2;
+ // For byte positions 0 through 7,
+ // scan all records and count how many fall into each bucket at that byte position.
+ for (int i = 0; i <= 7; i++) { + if (((bitsChanged2 >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + for (long offset = baseOffset; offset < limit; offset += 24) { + counts[i][(int) ((Platform.getLong(baseObject, offset + 16) >>> (i * 8)) & 0xff)]++; + } + } + }
+ + + return counts; + }
+ + /** + * Specialization of sortAtByte() for key-prefix arrays.
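The pass order chosen in sortKeyPrefixArray, all of prefix2's bytes before prefix1's, is the standard least-significant-key argument: every per-byte pass is stable, so sorting on the secondary key first and the primary key last leaves records ordered by (prefix1, prefix2). The same property demonstrated with an ordinary stable library sort, purely as an illustration (needs java.util.ArrayList, Arrays, Comparator, List):

    List<long[]> rows = new ArrayList<>(Arrays.asList(
        new long[]{2, 1}, new long[]{1, 9}, new long[]{2, 0}, new long[]{1, 3}));
    rows.sort(Comparator.comparingLong((long[] r) -> r[1]));  // secondary key first
    rows.sort(Comparator.comparingLong((long[] r) -> r[0]));  // stable, so ties keep the secondary order
    // rows is now {1,3}, {1,9}, {2,0}, {2,1}: ordered by (primary, secondary)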
+ */ + private static void sortKeyPrefixArrayAtByte( + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, + boolean desc, boolean signed) {
+ assert counts.length == 256;
+ // Convert counts into the LongArray positions where each bucket's records will be written (the unused half of the LongArray acts as the bucket storage).
+ long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 24, desc, signed);
+ Object baseObject = array.getBaseObject();
+ // Start position of the first record.
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ // End position of the last record.
+ long maxOffset = baseOffset + numRecords * 24L;
+ // Walk over every record.
+ for (long offset = baseOffset; offset < maxOffset; offset += 24) {
+ // The record pointer.
+ long key = Platform.getLong(baseObject, offset);
+ // The record's two key prefixes.
+ long prefix1 = Platform.getLong(baseObject, offset + 8);
+ long prefix2 = Platform.getLong(baseObject, offset + 16);
+ // Compute which bucket the record falls into at the current byte position.
+ // byteIdx is in [8, 15] here; Java masks long shift amounts to 6 bits, so this reads byte (byteIdx - 8) of prefix1.
+ int bucket = (int) ((prefix1 >>> (byteIdx * 8)) & 0xff);
+ // Destination position for this record.
+ long dest = offsets[bucket];
+ // Write the record pointer.
+ Platform.putLong(baseObject, dest, key);
+ // Write the two prefixes.
+ Platform.putLong(baseObject, dest + 8, prefix1);
+ Platform.putLong(baseObject, dest + 16, prefix2);
+ // More records may land in this bucket, so advance the destination by one record (24 bytes).
+ offsets[bucket] += 24;
+ } + }
+ + private static void sortKeyPrefixArrayAtByte2( + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, + boolean desc, boolean signed) {
+ assert counts.length == 256;
+ // Convert counts into the LongArray positions where each bucket's records will be written (the unused half of the LongArray acts as the bucket storage).
+ long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 24, desc, signed);
+ Object baseObject = array.getBaseObject();
+ // Start position of the first record.
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ // End position of the last record.
+ long maxOffset = baseOffset + numRecords * 24L;
+ // Walk over every record.
+ for (long offset = baseOffset; offset < maxOffset; offset += 24) {
+ // The record pointer.
+ long key = Platform.getLong(baseObject, offset);
+ // The record's two key prefixes.
+ long prefix1 = Platform.getLong(baseObject, offset + 8);
+ long prefix2 = Platform.getLong(baseObject, offset + 16);
+ // Compute which bucket the record falls into at the current byte position (this variant keys on prefix2).
+ int bucket = (int) ((prefix2 >>> (byteIdx * 8)) & 0xff);
+ // Destination position for this record.
+ long dest = offsets[bucket];
+ // Write the record pointer.
+ Platform.putLong(baseObject, dest, key);
+ // Write the two prefixes.
+ Platform.putLong(baseObject, dest + 8, prefix1);
+ Platform.putLong(baseObject, dest + 16, prefix2);
+ // More records may land in this bucket, so advance the destination by one record (24 bytes).
+ offsets[bucket] += 24;
+ } + }
+ +}
diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalRadixSorter.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalRadixSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..da43ad10905ae3bbccf44fb5586c624b5045de06 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalRadixSorter.java @@ -0,0 +1,796 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.function.Supplier; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemoryRadixSorter}. + */ +public final class UnsafeExternalRadixSorter extends MemoryConsumer { + + private static final Logger logger = LoggerFactory.getLogger(UnsafeExternalRadixSorter.class); + + @Nullable + private final List prefixComparators; + + /** + * {@link RecordComparator} may probably keep the reference to the records they compared last + * time, so we should not keep a {@link RecordComparator} instance inside + * {@link UnsafeExternalRadixSorter}, because {@link UnsafeExternalRadixSorter} is referenced by + * {@link TaskContext} and thus can not be garbage collected until the end of the task. + */ + @Nullable + private final Supplier recordComparatorSupplier; + + private final TaskMemoryManager taskMemoryManager; + private final BlockManager blockManager; + private final SerializerManager serializerManager; + private final TaskContext taskContext; + + /** + * The buffer size to use when writing spills using DiskBlockObjectWriter + */ + private final int fileBufferSizeBytes; + + /** + * Force this sorter to spill when there are this many elements in memory. + */ + private final int numElementsForSpillThreshold; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). 
+ */ + private final LinkedList allocatedPages = new LinkedList<>(); + + private final LinkedList spillWriters = new LinkedList<>(); + + // These variables are reset after spilling: + @Nullable + private volatile UnsafeInMemoryRadixSorter inMemSorter; + + private MemoryBlock currentPage = null; + private long pageCursor = -1; + private long peakMemoryUsedBytes = 0; + private long totalSpillBytes = 0L; + private long totalSortTimeNanos = 0L; + private volatile SpillableIterator readingIterator = null; + + public static UnsafeExternalRadixSorter createWithExistingInMemorySorter( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + List prefixComparators, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + UnsafeInMemoryRadixSorter inMemorySorter, + long existingMemoryConsumption) throws IOException { + UnsafeExternalRadixSorter sorter = new UnsafeExternalRadixSorter(taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparatorSupplier, prefixComparators, initialSize, + pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); + sorter.spill(Long.MAX_VALUE, sorter); + taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption); + sorter.totalSpillBytes += existingMemoryConsumption; + // The external sorter will be used to insert records, in-memory sorter is not needed. + sorter.inMemSorter = null; + return sorter; + } + + public static UnsafeExternalRadixSorter create( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + List prefixComparators, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + boolean canUseRadixSort) { + return new UnsafeExternalRadixSorter(taskMemoryManager, blockManager, serializerManager, + taskContext, recordComparatorSupplier, prefixComparators, initialSize, pageSizeBytes, + numElementsForSpillThreshold, null, canUseRadixSort); + } + + private UnsafeExternalRadixSorter( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + List prefixComparators, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + @Nullable UnsafeInMemoryRadixSorter existingInMemorySorter, + boolean canUseRadixSort) { + super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); + this.taskMemoryManager = taskMemoryManager; + this.blockManager = blockManager; + this.serializerManager = serializerManager; + this.taskContext = taskContext; + this.recordComparatorSupplier = recordComparatorSupplier; + this.prefixComparators = prefixComparators; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 + this.fileBufferSizeBytes = 32 * 1024; + + if (existingInMemorySorter == null) { + RecordComparator comparator = null; + if (recordComparatorSupplier != null) { + comparator = recordComparatorSupplier.get(); + } + this.inMemSorter = new UnsafeInMemoryRadixSorter( + this, + taskMemoryManager, + comparator, + prefixComparators, + initialSize, + canUseRadixSort); + } else { + this.inMemSorter = existingInMemorySorter; + } + this.peakMemoryUsedBytes = getMemoryUsage(); + 
this.numElementsForSpillThreshold = numElementsForSpillThreshold; + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addTaskCompletionListener(context -> { + cleanupResources(); + }); + } + + /** + * Marks the current page as no-more-space-available, and as a result, either allocate a + * new page or spill when we see the next record. + */ + @VisibleForTesting + public void closeCurrentPage() { + if (currentPage != null) { + pageCursor = currentPage.getBaseOffset() + currentPage.size(); + } + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this) { + if (readingIterator != null) { + return readingIterator.spill(); + } + return 0L; // this should throw exception + } + + if (inMemSorter == null || inMemSorter.numRecords() <= 0) { + // There could still be some memory allocated when there are no records in the in-memory + // sorter. We will not spill it however, to ensure that we can always process at least one + // record before spilling. See the comments in `allocateMemoryForRecordIfNecessary` for why + // this is necessary. + return 0L; + } + + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); + + final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. This also counts the space needed to store the sorter's pointer array. + inMemSorter.freeMemory(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory + // pages, we might not be able to get memory for the pointer array. + + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); + totalSpillBytes += spillSize; + return spillSize; + } + + /** + * Return the total memory usage of this sorter, including the data pages and the sorter's pointer + * array. + */ + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** + * @return the total amount of time spent sorting data (in-memory only). 
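The spill path above feeds both Spark's task metrics and the sorter's own counters, so after a sort completes a caller can report memory behaviour directly from the sorter; the getters involved are the ones defined in this class (the `sorter` variable below is hypothetical):

    long peakBytes    = sorter.getPeakMemoryUsedBytes(); // high-water mark of data pages + pointer array
    long spilledBytes = sorter.getSpillSize();           // total bytes released to disk so far
    long sortNanos    = sorter.getSortTimeNanos();       // in-memory sort time only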
+ */ + public long getSortTimeNanos() { + UnsafeInMemoryRadixSorter sorter = inMemSorter; + if (sorter != null) { + return sorter.getSortTimeNanos(); + } + return totalSortTimeNanos; + } + + /** + * Return the total number of bytes that has been spilled into disk so far. + */ + public long getSpillSize() { + return totalSpillBytes; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + /** + * Free this sorter's data pages. + * + * @return the number of bytes freed. + */ + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryFreed += block.size(); + freePage(block); + } + allocatedPages.clear(); + currentPage = null; + pageCursor = 0; + return memoryFreed; + } + + /** + * Deletes any spill files created by this sorter. + */ + private void deleteSpillFiles() { + for (UnsafeSorterSpillWriter spill : spillWriters) { + File file = spill.getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + + /** + * Frees this sorter's in-memory data structures and cleans up its spill files. + */ + public void cleanupResources() { + synchronized (this) { + deleteSpillFiles(); + freeMemory(); + if (inMemSorter != null) { + inMemSorter.freeMemory(); + inMemSorter = null; + } + } + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert (inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + if (inMemSorter.numRecords() <= 0) { + // Spilling was triggered just before this method was called. The pointer array was freed + // during the spill, so a new pointer array needs to be allocated here. + LongArray array = allocateArray(inMemSorter.getInitialSize()); + inMemSorter.expandPointerArray(array); + return; + } + + long used = inMemSorter.getMemoryUsage(); + LongArray array = null; + try { + // could trigger spilling + array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + } catch (SparkOutOfMemoryError e) { + if (inMemSorter.numRecords() > 0) { + logger.error("Unable to grow the pointer array"); + throw e; + } + // The new array could not be allocated, but that is not an issue as it is longer needed, + // as all records were spilled. + } + + if (inMemSorter.numRecords() <= 0) { + // Spilling was triggered while trying to allocate the new array. + if (array != null) { + // We succeeded in allocating the new array, but, since all records were spilled, a + // smaller array would also suffice. + freeArray(array); + } + // The pointer array was freed during the spill, so a new pointer array needs to be + // allocated here. + array = allocateArray(inMemSorter.getInitialSize()); + } + inMemSorter.expandPointerArray(array); + } + } + + /** + * Allocates an additional page in order to insert an additional record. This will request + * additional memory from the memory manager and spill if the requested memory can not be + * obtained. + * + * @param required the required space in the data page, in bytes, including space for storing + * the record size. 
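The `required` value threaded through acquireNewPageIfNecessary and allocateMemoryForRecordIfNecessary is always the payload length plus a length word, because each record is framed inside its page as a 4- or 8-byte size field followed by the row bytes (see insertRecord below). A stripped-down sketch of that framing, with page allocation and address encoding elided (the helper name is made up):

    // Framing of one record inside a data page:
    //   [ record length : uaoSize bytes ][ record payload : length bytes ]
    static long writeFramedRecord(Object pageBase, long pageCursor,
                                  Object recordBase, long recordOffset, int length) {
        int uaoSize = UnsafeAlignedOffset.getUaoSize();               // 4 by default
        UnsafeAlignedOffset.putSize(pageBase, pageCursor, length);    // write the length word
        Platform.copyMemory(recordBase, recordOffset, pageBase, pageCursor + uaoSize, length);
        return pageCursor + uaoSize + length;                         // cursor advanced by "required" bytes
    }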
+ */ + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { + // TODO: try to find space on previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the memory manager and spill if the requested memory can not be obtained. + * + * @param required the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateMemoryForRecordIfNecessary(int required) throws IOException { + // Step 1: + // Ensure that the pointer array has space for another record. This may cause a spill. + growPointerArrayIfNecessary(); + // Step 2: + // Ensure that the last page has space for another record. This may cause a spill. + acquireNewPageIfNecessary(required); + // Step 3: + // The allocation in step 2 could have caused a spill, which would have freed the pointer + // array allocated in step 1. Therefore we need to check again whether we have to allocate + // a new pointer array. + // + // If the allocation in this step causes a spill event then it will not cause the page + // allocated in the previous step to be freed. The function `spill` only frees memory if at + // least one record has been inserted in the in-memory sorter. This will not be the case if + // we have spilled in the previous step. + // + // If we did not spill in the previous step then `growPointerArrayIfNecessary` will be a + // no-op that does not allocate any memory, and therefore can't cause a spill event. + // + // Thus there is no need to call `acquireNewPageIfNecessary` again after this step. + growPointerArrayIfNecessary(); + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBase, long recordOffset, int length, + long prefix1, boolean prefix1IsNull, + long prefix2, boolean prefix2IsNull) + throws IOException { + + assert (inMemSorter != null); + if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { + logger.info("Spilling data because number of spilledRecords crossed the threshold " + + numElementsForSpillThreshold); + spill(); + } + + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; + allocateMemoryForRecordIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + inMemSorter.insertRecord(recordAddress, prefix1, prefix1IsNull, prefix2, prefix2IsNull); + } + + /** + * Write a key-value record to the sorter. The key and value will be put together in-memory, + * using the following format: + *
+ * record length (4 bytes), key length (4 bytes), key data, value data + *
+ * record length = key length + value length + 4 + */ + + + /** + * Merges another UnsafeExternalRadixSorters into this one, the other one will be emptied. + */ + public void merge(UnsafeExternalRadixSorter other) throws IOException { + other.spill(); + totalSpillBytes += other.totalSpillBytes; + spillWriters.addAll(other.spillWriters); + // remove them from `spillWriters`, or the files will be deleted in `cleanupResources`. + other.spillWriters.clear(); + other.cleanupResources(); + } + + /** + * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getSortedIterator() throws IOException { + assert (recordComparatorSupplier != null); + if (spillWriters.isEmpty()) { + assert (inMemSorter != null); + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + return readingIterator; + } else { + final UnsafeRadixSorterSpillMerger spillMerger = new UnsafeRadixSorterSpillMerger( + recordComparatorSupplier.get(), prefixComparators, spillWriters.size()); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); + } + if (inMemSorter != null) { + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + spillMerger.addSpillIfNotEmpty(readingIterator); + } + return spillMerger.getSortedIterator(); + } + } + + @VisibleForTesting + boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + + private static void spillIterator(UnsafeSorterIterator inMemIterator, + UnsafeSorterSpillWriter spillWriter) throws IOException { + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + } + + /** + * An UnsafeSorterIterator that support spilling. + */ + class SpillableIterator extends UnsafeSorterIterator { + private UnsafeSorterIterator upstream; + private MemoryBlock lastPage = null; + private boolean loaded = false; + private int numRecords; + + private Object currentBaseObject; + private long currentBaseOffset; + private int currentRecordLength; + private long currentKeyPrefix; + + SpillableIterator(UnsafeSorterIterator inMemIterator) { + this.upstream = inMemIterator; + this.numRecords = inMemIterator.getNumRecords(); + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public long getCurrentPageNumber() { + throw new UnsupportedOperationException(); + } + + public long spill() throws IOException { + synchronized (this) { + if (inMemSorter == null) { + return 0L; + } + + long currentPageNumber = upstream.getCurrentPageNumber(); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + if (numRecords > 0) { + // Iterate over the records that have not been returned and spill them. + final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter( + blockManager, fileBufferSizeBytes, writeMetrics, numRecords); + spillIterator(upstream, spillWriter); + spillWriters.add(spillWriter); + upstream = spillWriter.getReader(serializerManager); + } else { + // Nothing to spill as all records have been read already, but do not return yet, as the + // memory still has to be freed. 
+ upstream = null; + } + + long released = 0L; + synchronized (UnsafeExternalRadixSorter.this) { + // release the pages except the one that is used. There can still be a caller that + // is accessing the current record. We free this page in that caller's next loadNext() + // call. + for (MemoryBlock page : allocatedPages) { + if (!loaded || page.pageNumber != currentPageNumber) { + released += page.size(); + freePage(page); + } else { + lastPage = page; + } + } + allocatedPages.clear(); + if (lastPage != null) { + // Add the last page back to the list of allocated pages to make sure it gets freed in + // case loadNext() never gets called again. + allocatedPages.add(lastPage); + } + } + + // in-memory sorter will not be used after spilling + assert (inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + totalSortTimeNanos += inMemSorter.getSortTimeNanos(); + inMemSorter.freeMemory(); + inMemSorter = null; + taskContext.taskMetrics().incMemoryBytesSpilled(released); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); + totalSpillBytes += released; + return released; + } + } + + @Override + public boolean hasNext() { + return numRecords > 0; + } + + @Override + public void loadNext() throws IOException { + assert upstream != null; + MemoryBlock pageToFree = null; + try { + synchronized (this) { + loaded = true; + // Just consumed the last record from the in-memory iterator. + if (lastPage != null) { + // Do not free the page here, while we are locking `SpillableIterator`. The `freePage` + // method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in + // sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and + // `SpillableIterator` in sequence, which may happen in + // `TaskMemoryManager.acquireExecutionMemory`. + pageToFree = lastPage; + allocatedPages.clear(); + lastPage = null; + } + numRecords--; + upstream.loadNext(); + + // Keep track of the current base object, base offset, record length, and key prefix, + // so that the current record can still be read in case a spill is triggered and we + // switch to the spill writer's iterator. + currentBaseObject = upstream.getBaseObject(); + currentBaseOffset = upstream.getBaseOffset(); + currentRecordLength = upstream.getRecordLength(); + currentKeyPrefix = upstream.getKeyPrefix(); + } + } finally { + if (pageToFree != null) { + freePage(pageToFree); + } + } + } + + @Override + public Object getBaseObject() { + return currentBaseObject; + } + + @Override + public long getBaseOffset() { + return currentBaseOffset; + } + + @Override + public int getRecordLength() { + return currentRecordLength; + } + + @Override + public long getKeyPrefix() { + return currentKeyPrefix; + } + } + + /** + * Returns an iterator starts from startIndex, which will return the rows in the order as + * inserted. + *
+ * It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + *
+ * TODO: support forced spilling + */ + public UnsafeSorterIterator getIterator(int startIndex) throws IOException { + if (spillWriters.isEmpty()) { + assert (inMemSorter != null); + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex); + return iter; + } else { + LinkedList queue = new LinkedList<>(); + int i = 0; + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + if (i + spillWriter.recordsSpilled() > startIndex) { + UnsafeSorterIterator iter = spillWriter.getReader(serializerManager); + moveOver(iter, startIndex - i); + queue.add(iter); + } + i += spillWriter.recordsSpilled(); + } + if (inMemSorter != null && inMemSorter.numRecords() > 0) { + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex - i); + queue.add(iter); + } + return new ChainedIterator(queue); + } + } + + private void moveOver(UnsafeSorterIterator iter, int steps) + throws IOException { + if (steps > 0) { + for (int i = 0; i < steps; i++) { + if (iter.hasNext()) { + iter.loadNext(); + } else { + throw new ArrayIndexOutOfBoundsException("Failed to move the iterator " + steps + + " steps forward"); + } + } + } + } + + /** + * Chain multiple UnsafeSorterIterator together as single one. + */ + static class ChainedIterator extends UnsafeSorterIterator { + + private final Queue iterators; + private UnsafeSorterIterator current; + private int numRecords; + + ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.numRecords = 0; + for (UnsafeSorterIterator iter : iterators) { + this.numRecords += iter.getNumRecords(); + } + this.iterators = iterators; + this.current = iterators.remove(); + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public long getCurrentPageNumber() { + return current.getCurrentPageNumber(); + } + + @Override + public boolean hasNext() { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + return current.hasNext(); + } + + @Override + public void loadNext() throws IOException { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + current.loadNext(); + } + + @Override + public Object getBaseObject() { + return current.getBaseObject(); + } + + @Override + public long getBaseOffset() { + return current.getBaseOffset(); + } + + @Override + public int getRecordLength() { + return current.getRecordLength(); + } + + @Override + public long getKeyPrefix() { + return current.getKeyPrefix(); + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemoryRadixSorter.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemoryRadixSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..1004aee3e695e43c2a9a34ff94d95ed0b7e0c97f --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemoryRadixSorter.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; + +import javax.annotation.Nullable; + +import org.apache.spark.TaskContext; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemoryRadixSorter { + + private final MemoryConsumer consumer; + private final TaskMemoryManager memoryManager; + + /** + * If non-null, specifies the radix sort parameters and that radix sort will be used. + */ + @Nullable + private final List radixSortSupports; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + *
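+ * In this two-column variant each record occupies three consecutive {@code long} slots: the record pointer at {@code 3 * i}, the first key prefix at {@code 3 * i + 1}, and the second key prefix at {@code 3 * i + 2}. + *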
+ * Only part of the array will be used to store the pointers, the rest part is preserved as + * temporary buffer for sorting. + */ + private LongArray array; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pos = 0; + + /** + * If sorting with radix sort, specifies the starting position in the sort buffer where records + * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed + * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid + * radix sorting over null values. + */ + private int nullBoundaryPos = 0; + + /* + * How many records could be inserted, because part of the array should be left for sorting. + */ + private int usableCapacity = 0; + + private long initialSize; + + private long totalSortTimeNanos = 0L; + + public UnsafeInMemoryRadixSorter( + final MemoryConsumer consumer, + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final List prefixComparators, + int initialSize, + boolean canUseRadixSort) { + this(consumer, memoryManager, recordComparator, prefixComparators, + consumer.allocateArray(initialSize * 2L), canUseRadixSort); + } + + public UnsafeInMemoryRadixSorter( + final MemoryConsumer consumer, + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final List prefixComparators, + LongArray array, + boolean canUseRadixSort) { + this.consumer = consumer; + this.memoryManager = memoryManager; + this.initialSize = array.size(); + if (recordComparator != null) { + if (canUseRadixSort) { + this.radixSortSupports = prefixComparators.stream() + .map(prefixComparator -> ((PrefixComparators.RadixSortSupport) prefixComparator)) + .collect(Collectors.toList()); + } else { + this.radixSortSupports = null; + } + } else { + this.radixSortSupports = null; + } + this.array = array; + this.usableCapacity = getUsableCapacity(); + } + + private int getUsableCapacity() { + // Radix sort requires same amount of used memory as buffer, Tim sort requires + // half of the used memory as buffer. + return (int) (array.size() / (radixSortSupports != null ? 2 : 1.5)); + } + + public long getInitialSize() { + return initialSize; + } + + /** + * Free the memory used by pointer array. + */ + public void freeMemory() { + if (consumer != null) { + if (array != null) { + consumer.freeArray(array); + } + + // Set the array to null instead of allocating a new array. Allocating an array could have + // triggered another spill and this method already is called from UnsafeExternalSorter when + // spilling. Attempting to allocate while spilling is dangerous, as we could be holding onto + // a large partially complete allocation, which may prevent other memory from being allocated. + // Instead we will allocate the new array when it is necessary. + array = null; + usableCapacity = 0; + } + pos = 0; + nullBoundaryPos = 0; + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pos / 3; + } + + /** + * @return the total amount of time spent sorting data (in-memory only). 
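+ * The time is accumulated across calls to {@code getSortedIterator()}.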
+ */ + public long getSortTimeNanos() { + return totalSortTimeNanos; + } + + public long getMemoryUsage() { + if (array == null) { + return 0L; + } + + return array.size() * 8; + } + + public boolean hasSpaceForAnotherRecord() { + return pos + 1 < usableCapacity; + } + + public void expandPointerArray(LongArray newArray) { + if (array != null) { + if (newArray.size() < array.size()) { + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + // checkstyle.on: RegexpSinglelineJava + } + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L); + consumer.freeArray(array); + } + array = newArray; + usableCapacity = getUsableCapacity(); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix1 a user-defined key prefix + */ + public void insertRecord(long recordPointer, + long keyPrefix1, boolean prefix1IsNull, + long keyPrefix2, boolean prefix2IsNull) { + if (!hasSpaceForAnotherRecord()) { + throw new IllegalStateException("There is no space for new record"); + } + assert radixSortSupports != null; + boolean prefix2Desc = radixSortSupports.get(1).sortDescending(); + if (prefix1IsNull) { + // Swap forward a non-null record to make room for this one at the beginning of the array. + array.set(pos, array.get(nullBoundaryPos)); + pos++; + array.set(pos, array.get(nullBoundaryPos + 1)); + pos++; + array.set(pos, array.get(nullBoundaryPos + 2)); + pos++; + + // Place this record in the vacated position. + array.set(nullBoundaryPos, recordPointer); + nullBoundaryPos++; + array.set(nullBoundaryPos, keyPrefix1); + nullBoundaryPos++; + // prefix2是null的情况 + if (prefix2Desc) { + array.set(nullBoundaryPos, Long.MAX_VALUE - keyPrefix2); + } else { + array.set(nullBoundaryPos, keyPrefix2); + } + nullBoundaryPos++; + } else { + // 行记录位置 + array.set(pos, recordPointer); + pos++; + // 修改2,前缀,这里放的时候需要放2个 + array.set(pos, keyPrefix1); + pos++; + if (prefix2Desc) { + array.set(pos, Long.MAX_VALUE - keyPrefix2); + } else { + array.set(pos, keyPrefix2); + } + pos++; + } + } + + public final class SortedIterator extends UnsafeSorterIterator implements Cloneable { + + private final int numRecords; + private int position; + private int offset; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + + private long keyPrefix2; + private int recordLength; + private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); + + private SortedIterator(int numRecords, int offset) { + this.numRecords = numRecords; + this.position = 0; + this.offset = offset; + } + + @Override + public SortedIterator clone() { + SortedIterator iter = new SortedIterator(numRecords, offset); + iter.position = position; + iter.baseObject = baseObject; + iter.baseOffset = baseOffset; + iter.keyPrefix = keyPrefix; + iter.recordLength = recordLength; + iter.currentPageNumber = currentPageNumber; + return iter; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public boolean hasNext() { + return position / 3 < numRecords; + } + + @Override + public void loadNext() { + // Kill the task in case it has been marked as killed. 
This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = array.get(offset + position); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + baseObject = memoryManager.getPage(recordPointer); + // Skip over record length + baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; + recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); + + keyPrefix = array.get(offset + position + 1); + keyPrefix2 = array.get(offset + position + 2); + position += 3; + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public long getCurrentPageNumber() { + return currentPageNumber; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + if (numRecords() == 0) { + // `array` might be null, so make sure that it is not accessed by returning early. + return new SortedIterator(0, 0); + } + + int offset = 0; + long start = System.nanoTime(); + if (this.radixSortSupports != null) { + // 拿到排完序后的数据在数组中的结束位置 + offset = RadixTwoColumnSort.sortKeyPrefixArray( + array, nullBoundaryPos, (pos - nullBoundaryPos) / 3L, 0, 15, + radixSortSupports.get(0).sortDescending(), radixSortSupports.get(0).sortSigned(), 8); + } + + totalSortTimeNanos += System.nanoTime() - start; + if (nullBoundaryPos > 0) { + assert radixSortSupports != null : "Nulls are only stored separately with radix sort"; + LinkedList queue = new LinkedList<>(); + + // The null order is either LAST or FIRST, regardless of sorting direction (ASC|DESC) + + if (radixSortSupports.get(0).nullsFirst()) { + queue.add(new SortedIterator(nullBoundaryPos / 3, 0)); + queue.add(new SortedIterator((pos - nullBoundaryPos) / 3, offset)); + } else { + queue.add(new SortedIterator((pos - nullBoundaryPos) / 3, offset)); + queue.add(new SortedIterator(nullBoundaryPos / 3, 0)); + } + return new UnsafeExternalSorter.ChainedIterator(queue); + } else { + return new SortedIterator(pos / 3, offset); + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeRadixSorterSpillMerger.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeRadixSorterSpillMerger.java new file mode 100644 index 0000000000000000000000000000000000000000..11d8cbf83ceba88b40648d0228dea8733dd6750b --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeRadixSorterSpillMerger.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; + +final class UnsafeRadixSorterSpillMerger { + + private int numRecords = 0; + private final PriorityQueue priorityQueue; + + UnsafeRadixSorterSpillMerger( + RecordComparator recordComparator, + List prefixComparators, + int numSpills) { + Comparator comparator = (left, right) -> { + int prefixComparisonResult = + prefixComparators.get(0).compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); + } else { + return prefixComparisonResult; + } + }; + priorityQueue = new PriorityQueue<>(numSpills, comparator); + } + + /** + * Add an UnsafeSorterIterator to this merger + */ + public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + // We only add the spillReader to the priorityQueue if it is not empty. We do this to + // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator + // does not return wrong result because hasNext will return true + // at least priorityQueue.size() times. If we allow n spillReaders in the + // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator. 
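+ // Advance to the first record before queueing so the priority queue comparator has a record to compare.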
+ spillReader.loadNext(); + priorityQueue.add(spillReader); + numRecords += spillReader.getNumRecords(); + } + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public long getCurrentPageNumber() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { + return spillReader.getBaseObject(); + } + + @Override + public long getBaseOffset() { + return spillReader.getBaseOffset(); + } + + @Override + public int getRecordLength() { + return spillReader.getRecordLength(); + } + + @Override + public long getKeyPrefix() { + return spillReader.getKeyPrefix(); + } + }; + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala new file mode 100644 index 0000000000000000000000000000000000000000..e1a9fdce53a5ee2ff1dc75b78eb6439618748bbf --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/ColumnarPlugin.scala @@ -0,0 +1,735 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.omnioffload.spark + +import com.huawei.boostkit.omnidata.spark.NdpConnectorUtils +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommandExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.ndp.NdpConf +import org.apache.spark.sql.execution.ndp.NdpConf.getOptimizerPushDownThreshold +import org.apache.spark.sql.hive.execution.CreateHiveTableAsSelectCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataTypes, DoubleType, LongType} +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import java.net.URI +import scala.collection.JavaConverters + +case class NdpOverrides(sparkSession: SparkSession) extends Rule[SparkPlan] { + + var numPartitions: Int = -1 + var pushDownTaskCount: Int = -1 + var isSMJ = false + var isSort = false + var hasCoalesce = false + var hasShuffle = false + var ACCURATE_QUERY_HD = "153" + var RADIX_SORT_COLUMN_NUMS = 2 + + def apply(plan: SparkPlan): SparkPlan = { + preRuleApply(plan) + val ruleList = Seq(CountReplaceRule) + val afterPlan = ruleList.foldLeft(plan) { case (sp, rule) => + val result = rule.apply(sp) + result + } + val operatorEnable = NdpConf.getNdpOperatorCombineEnabled(sparkSession) + val optimizedPlan = if (operatorEnable) { + replaceWithOptimizedPlan(afterPlan) + } else { + replaceWithOptimizedPlanNoOperator(afterPlan) + } + val finalPlan = replaceWithScanPlan(optimizedPlan) + postRuleApply(finalPlan) + finalPlan + } + + def preRuleApply(plan: SparkPlan): Unit = { + numPartitions = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.coalesce.numPartitions", + NdpConnectorUtils.getNdpNumPartitionsStr("10000")).toInt + pushDownTaskCount = NdpConnectorUtils.getPushDownTaskTotal(getOptimizerPushDownThreshold(sparkSession)) + if (CountReplaceRule.shouldReplaceCountOne(plan)) { + pushDownTaskCount = NdpConnectorUtils.getCountTaskTotal(50) + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getCountMaxPartSize("512MB")) + } + if (CountReplaceRule.shouldReplaceDistinctCount(plan)) { + pushDownTaskCount = NdpConnectorUtils.getCountDistinctTaskTotal(50) + } + } + + def postRuleApply(plan: SparkPlan): Unit = { + if (isSMJ) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getSMJMaxPartSize("536870912")) + } + } + + //now set task total number, we can use this number pushDown task in thread + def replaceWithScanPlan(plan: SparkPlan): SparkPlan = { + val p = plan.transformUp { + case scan: FileSourceScanExec => + scan.setRuntimePushDownSum(pushDownTaskCount) + if (hasCoalesce && !hasShuffle) { + // without shuffle , coalesce num is task num + scan.setRuntimePartSum(numPartitions) + } + scan + case p => p + 
} + p + } + + def replaceWithOptimizedPlan(plan: SparkPlan): SparkPlan = { + val p = plan.transformUp { + case shuffle: ShuffleExchangeExec => + hasShuffle = true + shuffle + case p@ColumnarSortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => + isSort = true + RadixSortExec(sortOrder, global, child, testSpillFrequency) + case p@SortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => + isSort = true + RadixSortExec(sortOrder, global, child, testSpillFrequency) + case p@DataWritingCommandExec(cmd, child) => + if (isSort || isVagueAndAccurateHd(child)) { + p + } else { + hasCoalesce = true + DataWritingCommandExec(cmd, CoalesceExec(numPartitions, child)) + } + case p@ColumnarSortMergeJoinExec(_, _, joinType, _, _, _, _, projectList) + if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + isSMJ = true + numPartitions = NdpConnectorUtils.getSMJNumPartitions(5000) + ColumnarSortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, + condition = p.condition, left = p.left, right = p.right, isSkewJoin = p.isSkewJoin, projectList) + case p@SortMergeJoinExec(_, _, joinType, _, _, _, _) + if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + isSMJ = true + numPartitions = NdpConnectorUtils.getSMJNumPartitions(5000) + SortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, + left = p.left, right = p.right, isSkewJoin = p.isSkewJoin) + case p@ColumnarBroadcastHashJoinExec(_, _, joinType, _, _, _, _, _, projectList) if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + ColumnarBroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, + joinType = LeftAnti, buildSide = p.buildSide, condition = p.condition, left = p.left, + right = p.right, isNullAwareAntiJoin = p.isNullAwareAntiJoin, projectList) + case p@BroadcastHashJoinExec(_, _, joinType, _, _, _, _, _) if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + BroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, + buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, + isNullAwareAntiJoin = p.isNullAwareAntiJoin) + case p@ColumnarShuffledHashJoinExec(_, _, joinType, _, _, _, _, projectList) + if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + ColumnarShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, + p.left, p.right, projectList) + case p@ShuffledHashJoinExec(_, _, joinType, _, _, _, _) if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + ShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, p.left, p.right) + case p@FilterExec(condition, child: OmniColumnarToRowExec, selectivity) => + val childPlan = child.transform { + case p@OmniColumnarToRowExec(child: NdpFileSourceScanExec) => + ColumnarToRowExec(FileSourceScanExec(child.relation, + child.output, + child.requiredSchema, + child.partitionFilters, + child.optionalBucketSet, + child.optionalNumCoalescedBuckets, + child.dataFilters, + child.tableIdentifier, + child.partitionColumn, + child.disableBucketedScan)) + case p@OmniColumnarToRowExec(child: FileSourceScanExec) => + ColumnarToRowExec(child) + case p => p + } + FilterExec(condition, childPlan, selectivity) + case 
c1@OmniColumnarToRowExec(c2@ColumnarFilterExec(condition, c3: FileSourceScanExec)) => + numPartitions = NdpConnectorUtils.getOmniColumnarNumPartitions(1000) + if (NdpPluginEnableFlag.isAccurate(condition)) { + pushDownTaskCount = NdpConnectorUtils.getOmniColumnarTaskCount(50) + } + FilterExec(condition, ColumnarToRowExec(c3)) + case p@FilterExec(condition, _, _) if NdpPluginEnableFlag.isAccurate(condition) => + numPartitions = NdpConnectorUtils.getFilterPartitions(1000) + pushDownTaskCount = NdpConnectorUtils.getFilterTaskCount(50) + p + case p@ColumnarConditionProjectExec(projectList, condition, child) + if condition.toString().startsWith("isnull") && (child.isInstanceOf[ColumnarSortMergeJoinExec] + || child.isInstanceOf[ColumnarBroadcastHashJoinExec] || child.isInstanceOf[ColumnarShuffledHashJoinExec]) && isTenPocProject(projectList) => + ColumnarProjectExec(changeProjectList(projectList), child) + case p@ProjectExec(projectList, filter: FilterExec) + if filter.condition.toString().startsWith("isnull") && (filter.child.isInstanceOf[SortMergeJoinExec] + || filter.child.isInstanceOf[BroadcastHashJoinExec] || filter.child.isInstanceOf[ShuffledHashJoinExec]) && isTenPocProject(projectList) => + ProjectExec(changeProjectList(projectList), filter.child) + case p: SortAggregateExec if p.child.isInstanceOf[OmniColumnarToRowExec] + && p.child.asInstanceOf[OmniColumnarToRowExec].child.isInstanceOf[ColumnarSortExec] + && isAggPartial(p.aggregateAttributes) => + val omniColumnarToRow = p.child.asInstanceOf[OmniColumnarToRowExec] + val omniColumnarSort = omniColumnarToRow.child.asInstanceOf[ColumnarSortExec] + SortAggregateExec(p.requiredChildDistributionExpressions, + p.groupingExpressions, + p.aggregateExpressions, + p.aggregateAttributes, + p.initialInputBufferOffset, + p.resultExpressions, + SortExec(omniColumnarSort.sortOrder, + omniColumnarSort.global, + ColumnarToRowExec(omniColumnarSort.child), + omniColumnarSort.testSpillFrequency)) + case p: SortAggregateExec if p.child.isInstanceOf[OmniColumnarToRowExec] + && p.child.asInstanceOf[OmniColumnarToRowExec].child.isInstanceOf[ColumnarSortExec] + && isAggFinal(p.aggregateAttributes) => + val omniColumnarToRow = p.child.asInstanceOf[OmniColumnarToRowExec] + val omniColumnarSort = omniColumnarToRow.child.asInstanceOf[ColumnarSortExec] + val omniShuffleExchange = omniColumnarSort.child.asInstanceOf[ColumnarShuffleExchangeExec] + val rowToOmniColumnar = omniShuffleExchange.child.asInstanceOf[RowToOmniColumnarExec] + SortAggregateExec(p.requiredChildDistributionExpressions, + p.groupingExpressions, + p.aggregateExpressions, + p.aggregateAttributes, + p.initialInputBufferOffset, + p.resultExpressions, + SortExec(omniColumnarSort.sortOrder, + omniColumnarSort.global, + ShuffleExchangeExec(omniShuffleExchange.outputPartitioning, rowToOmniColumnar.child, + omniShuffleExchange.shuffleOrigin), + omniColumnarSort.testSpillFrequency)) + case p@OmniColumnarToRowExec(agg: ColumnarHashAggregateExec) + if agg.groupingExpressions.nonEmpty && agg.child.isInstanceOf[ColumnarShuffleExchangeExec] => + val omniExchange = agg.child.asInstanceOf[ColumnarShuffleExchangeExec] + val omniHashAgg = omniExchange.child.asInstanceOf[ColumnarHashAggregateExec] + HashAggregateExec(agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + ShuffleExchangeExec(omniExchange.outputPartitioning, + 
HashAggregateExec(omniHashAgg.requiredChildDistributionExpressions, + omniHashAgg.groupingExpressions, + omniHashAgg.aggregateExpressions, + omniHashAgg.aggregateAttributes, + omniHashAgg.initialInputBufferOffset, + omniHashAgg.resultExpressions, + ColumnarToRowExec(omniHashAgg.child)), + omniExchange.shuffleOrigin)) + case p => p + } + p + } + + def replaceWithOptimizedPlanNoOperator(plan: SparkPlan): SparkPlan = { + val p = plan.transformUp { + case shuffle: ShuffleExchangeExec => + hasShuffle = true + shuffle + case p@SortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => + isSort = true + RadixSortExec(sortOrder, global, child, testSpillFrequency) + case p@DataWritingCommandExec(cmd, child) => + if (isSort || isVagueAndAccurateHd(child)) { + p + } else { + hasCoalesce = true + DataWritingCommandExec(cmd, CoalesceExec(numPartitions, child)) + } + case p@SortMergeJoinExec(_, _, joinType, _, _, _, _) + if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + isSMJ = true + numPartitions = NdpConnectorUtils.getSMJNumPartitions(5000) + SortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, + left = p.left, right = p.right, isSkewJoin = p.isSkewJoin) + case p@BroadcastHashJoinExec(_, _, joinType, _, _, _, _, _) if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + BroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, + buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, + isNullAwareAntiJoin = p.isNullAwareAntiJoin) + case p@ShuffledHashJoinExec(_, _, joinType, _, _, _, _) if joinType.equals(LeftOuter) && isTenPocJoin(p.leftKeys) && isTenPocJoin(p.rightKeys) => + ShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, p.left, p.right) + case p@FilterExec(condition, _, _) if NdpPluginEnableFlag.isAccurate(condition) => + numPartitions = NdpConnectorUtils.getFilterPartitions(1000) + pushDownTaskCount = NdpConnectorUtils.getFilterTaskCount(50) + p + case p@ProjectExec(projectList, filter: FilterExec) + if filter.condition.toString().startsWith("isnull") && (filter.child.isInstanceOf[SortMergeJoinExec] + || filter.child.isInstanceOf[BroadcastHashJoinExec] || filter.child.isInstanceOf[ShuffledHashJoinExec]) && isTenPocProject(projectList) => + ProjectExec(changeProjectList(projectList), filter.child) + case p => p + } + p + } + + def isAggPartial(aggAttributes: Seq[Attribute]): Boolean = { + aggAttributes.exists(x => x.name.equals("max") || x.name.equals("maxxx")) + } + + def isAggFinal(aggAttributes: Seq[Attribute]): Boolean = { + aggAttributes.exists(x => x.name.contains("avg(cast")) + } + + def isVagueAndAccurateHd(child: SparkPlan): Boolean = { + var result = false + child match { + case filter: FilterExec => + filter.child match { + case columnarToRow: ColumnarToRowExec => + if (columnarToRow.child.isInstanceOf[FileSourceScanExec]) { + filter.condition.foreach { x => + if (x.isInstanceOf[StartsWith] || x.isInstanceOf[EndsWith] || x.isInstanceOf[Contains]) { + result = true + } + x match { + case literal: Literal if !literal.nullable && literal.value.toString.startsWith(ACCURATE_QUERY_HD) => + result = true + case _ => + } + } + } + case _ => + } + case _ => + } + result + } + + def changeProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = { + val p = projectList.map { + case exp: Alias => + Alias(Literal(null, 
exp.dataType), exp.name)( + exprId = exp.exprId, + qualifier = exp.qualifier, + explicitMetadata = exp.explicitMetadata, + nonInheritableMetadataKeys = exp.nonInheritableMetadataKeys + ) + case exp => exp + } + p + } + + def isRadixSortExecEnable(sortOrder: Seq[SortOrder]): Boolean = { + sortOrder.lengthCompare(RADIX_SORT_COLUMN_NUMS) == 0 && + sortOrder.head.dataType == LongType && + sortOrder.head.child.isInstanceOf[AttributeReference] && + sortOrder.head.child.asInstanceOf[AttributeReference].name.startsWith("col") && + sortOrder(1).dataType == LongType && + sortOrder(1).child.isInstanceOf[AttributeReference] && + sortOrder(1).child.asInstanceOf[AttributeReference].name.startsWith("col") && + SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.radixSort.enabled", "true").toBoolean + } + + def isTenPocProject(projectList: Seq[NamedExpression]): Boolean = { + projectList.forall { + case exp: Alias => + exp.child.isInstanceOf[AttributeReference] && exp.child.asInstanceOf[AttributeReference].name.startsWith("col") + case exp: AttributeReference => + exp.name.startsWith("col") + case _ => false + } + } + + def isTenPocJoin(keys: Seq[Expression]): Boolean = { + keys.forall { + case exp: AttributeReference => + exp.name.startsWith("col") + case _ => false + } + } +} + +case class NdpRules(session: SparkSession) extends ColumnarRule with Logging { + + def ndpOverrides: NdpOverrides = NdpOverrides(session) + + override def preColumnarTransitions: Rule[SparkPlan] = plan => { + plan + } + + override def postColumnarTransitions: Rule[SparkPlan] = plan => { + if (NdpPluginEnableFlag.isEnable(plan.sqlContext.sparkSession)) { + val rule = ndpOverrides + rule(plan) + } else { + plan + } + } + +} + +case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { + + var maxSizeInBytes: BigInt = 0L + + var ACCURATE_QUERY = "000" + val SORT_REPARTITION_PLANS: Seq[String] = Seq( + "Sort,HiveTableRelation", + "Sort,LogicalRelation", + "Sort,RepartitionByExpression,HiveTableRelation", + "Sort,RepartitionByExpression,LogicalRelation", + "Sort,Project,HiveTableRelation", + "Sort,Project,LogicalRelation", + "Sort,RepartitionByExpression,Project,HiveTableRelation", + "Sort,RepartitionByExpression,Project,LogicalRelation" + ) + + val SORT_REPARTITION_SIZE: Int = SQLConf.get.getConfString( + "spark.omni.sql.ndpPlugin.sort.repartition.size", + NdpConnectorUtils.getSortRepartitionSizeStr("104857600")).toInt + val DECIMAL_PRECISION: Int = SQLConf.get.getConfString( + "spark.omni.sql.ndpPlugin.cast.decimal.precision", + NdpConnectorUtils.getCastDecimalPrecisionStr("15")).toInt + val MAX_PARTITION_BYTES_ENABLE_FACTOR: Int = SQLConf.get.getConfString( + "spark.omni.sql.ndpPlugin.max.partitionBytesEnable.factor", + NdpConnectorUtils.getNdpMaxPtFactorStr("2")).toInt + + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (NdpPluginEnableFlag.isEnable(session)) { + val res = replaceWithOptimizedPlan(plan) + repartition(FileSystem.get(session.sparkContext.hadoopConfiguration), plan) + res + } else if (NdpPluginEnableFlag.isNdpOptimizedEnable(session)) { + applyOptimizedRules(plan) + } else { + plan + } + } + + def applyOptimizedRules(plan: LogicalPlan): LogicalPlan = { + plan.foreach { + case p@LogicalRelation(_, _, catalogTable, _) => + val sTable = catalogTable.get.identifier + val stats = session.sessionState.catalog.getTableMetadata(sTable).stats + if (stats.isDefined) { + val sizeInBytes = stats.get.sizeInBytes + if (sizeInBytes > maxSizeInBytes) { + var fileMaxBytes = "512MB" + var 
shufflePartition = "200" + if (sizeInBytes <= 1073741824L) { + fileMaxBytes = NdpConnectorUtils.getMixSqlBaseMaxFilePtBytesStr("256MB") + shufflePartition = NdpConnectorUtils.getShufflePartitionsStr("200") + } else if (sizeInBytes > 1073741824L && sizeInBytes < 1099511627776L) { + fileMaxBytes = NdpConnectorUtils.getMixSqlBaseMaxFilePtBytesStr("256MB") + shufflePartition = NdpConnectorUtils.getShufflePartitionsStr("1000") + } else { + fileMaxBytes = NdpConnectorUtils.getMixSqlBaseMaxFilePtBytesStr("128MB") + shufflePartition = NdpConnectorUtils.getShufflePartitionsStr("1000") + } + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, fileMaxBytes) + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, shufflePartition) + maxSizeInBytes = sizeInBytes + } + } + case _ => + } + plan + } + + def replaceWithOptimizedPlan(plan: LogicalPlan): LogicalPlan = { + plan.transformUp { + case CreateHiveTableAsSelectCommand(tableDesc, query, outputColumnNames, mode) + if isParquetEnable(tableDesc) + && checkParquetFieldNames(outputColumnNames) + && SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.parquetOutput.enabled", "true") + .toBoolean => + CreateDataSourceTableAsSelectCommand( + tableDesc.copy(provider = Option("parquet")), mode, query, outputColumnNames) + case a@Aggregate(groupingExpressions, aggregateExpressions, _) + if SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.castDecimal.enabled", "true") + .toBoolean => + var ifCast = false + if (groupingExpressions.nonEmpty && hasCount(aggregateExpressions)) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getCountAggMaxFilePtBytesStr("1024MB")) + } else if (groupingExpressions.nonEmpty && hasAvg(aggregateExpressions)) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getAvgAggMaxFilePtBytesStr("256MB")) + ifCast = true + } + if (ifCast) { + a.copy(aggregateExpressions = aggregateExpressions + .map(castSumAvgToBigInt) + .map(_.asInstanceOf[NamedExpression])) + } + else { + a + } + case j@Join(_, _, Inner, condition, _) => + // turnOffOperator() + // 6-x-bhj + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getBhjMaxFilePtBytesStr("512MB")) + j + case s@Sort(order, _, _) => + s.copy(order = order.map(e => e.copy(child = castStringExpressionToBigint(e.child)))) + case p => p + } + } + + def hasCount(aggregateExpressions: Seq[Expression]): Boolean = { + aggregateExpressions.exists { + case exp: Alias if (exp.child.isInstanceOf[AggregateExpression] + && exp.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Count]) => true + case _ => false + } + } + + def hasAvg(aggregateExpressions: Seq[Expression]): Boolean = { + aggregateExpressions.exists { + case exp: Alias if (exp.child.isInstanceOf[AggregateExpression] + && exp.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Average]) => true + case _ => false + } + } + + def isParquetEnable(tableDesc: CatalogTable): Boolean = { + if (tableDesc.provider.isEmpty || tableDesc.provider.get.equals("hive")) { + if (tableDesc.storage.outputFormat.isEmpty + || tableDesc.storage.serde.get.equals("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) { + return true + } + } + false + } + + // ,;{}()\n\t= and space are special characters in Parquet schema + def checkParquetFieldNames(outputColumnNames: Seq[String]): Boolean = { + outputColumnNames.forall(!_.matches(".*[ ,;{}()\n\t=].*")) + } + + def repartition(fs: FileSystem, plan: 
LogicalPlan): Unit = { + var tables = Seq[URI]() + var planContents = Seq[String]() + var maxPartitionBytesEnable = true + var existsProject = false + var existsTable = false + var existsAgg = false + var existAccurate = false + var existFilter = false + var existJoin = false + var existLike = false + var isMixSql = false + + + plan.foreach { + case p@HiveTableRelation(tableMeta, _, _, _, _) => + if (tableMeta.storage.locationUri.isDefined) { + tables :+= tableMeta.storage.locationUri.get + } + existsTable = true + planContents :+= p.nodeName + case p@LogicalRelation(_, _, catalogTable, _) => + if (catalogTable.isDefined && catalogTable.get.storage.locationUri.isDefined) { + tables :+= catalogTable.get.storage.locationUri.get + } + existsTable = true + planContents :+= p.nodeName + case p: Project => + maxPartitionBytesEnable &= (p.output.length * MAX_PARTITION_BYTES_ENABLE_FACTOR < p.inputSet.size) + existsProject = true + planContents :+= p.nodeName + case p: Aggregate => + maxPartitionBytesEnable = true + existsProject = true + existsAgg = true + planContents :+= p.nodeName + case p@Filter(condition, _) => + existAccurate |= NdpPluginEnableFlag.isAccurate(condition) + existFilter = true + existLike |= isLike(condition) + planContents :+= p.nodeName + case p: Join => + existJoin = true + planContents :+= p.nodeName + case p => + planContents :+= p.nodeName + } + + if(!existsTable){ + return + } + + // mix sql + isMixSql = existJoin && existsAgg + if (isMixSql) { + if (existAccurate) { + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getAggShufflePartitionsStr("200")) + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getMixSqlAccurateMaxFilePtBytesStr("1024MB")) + } else { + if (existLike) { + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getAggShufflePartitionsStr("200")) + } else { + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getShufflePartitionsStr("5000")) + } + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + NdpConnectorUtils.getMixSqlBaseMaxFilePtBytesStr("128MB")) + } + // base sql agg shuffle partition 200 ,other 5000 + } else { + repartitionShuffleForSort(fs, tables, planContents) + repartitionHdfsReadForDistinct(fs, tables, plan) + if (existJoin) { + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getShufflePartitionsStr("5000")) + } + } + } + + def repartitionShuffleForSort(fs: FileSystem, tables: Seq[URI], planContents: Seq[String]): Unit = { + if (!SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.radixSort.enabled", "true").toBoolean) { + return + } + + val planContent = planContents.mkString(",") + if (tables.length == 1 + && SORT_REPARTITION_PLANS.exists(planContent.contains(_))) { + val partitions = Math.max(1, fs.getContentSummary(new Path(tables.head)).getLength / SORT_REPARTITION_SIZE) + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, + NdpConnectorUtils.getSortShufflePartitionsStr("1000")) + turnOffOperator() + } + } + + def repartitionHdfsReadForDistinct(fs: FileSystem, tables: Seq[URI], plan: LogicalPlan): Unit = { + if (!SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.distinct.enabled", "true").toBoolean) { + return + } + if (tables.length != 1) { + return + } + + plan.foreach { + case Aggregate(groupingExpressions, aggregateExpressions, _) if groupingExpressions == aggregateExpressions => + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + 
NdpConnectorUtils.getGroupMaxFilePtBytesStr("1024MB")) + return + case _ => + } + } + + def castSumAvgToBigInt(expression: Expression): Expression = { + val exp = expression.transform { + case Average(cast: Cast) if cast.dataType.isInstanceOf[DoubleType] + && cast.child.isInstanceOf[AttributeReference] + && cast.child.asInstanceOf[AttributeReference].name.startsWith("col")=> + Average(Cast(cast.child, DataTypes.LongType)) + case Sum(cast: Cast) if cast.dataType.isInstanceOf[DoubleType] + && cast.child.isInstanceOf[AttributeReference] + && cast.child.asInstanceOf[AttributeReference].name.startsWith("col")=> + Sum(Cast(cast.child, DataTypes.LongType)) + case e => + e + } + var finalExp = exp + exp match { + case agg: Alias if agg.child.isInstanceOf[AggregateExpression] + && agg.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Sum] => + finalExp = Alias(Cast(agg.child, DataTypes.DoubleType), agg.name)( + exprId = agg.exprId, + qualifier = agg.qualifier, + explicitMetadata = agg.explicitMetadata, + nonInheritableMetadataKeys = agg.nonInheritableMetadataKeys + ) + case _ => + } + finalExp + } + + def castStringExpressionToBigint(expression: Expression): Expression = { + expression match { + case a@AttributeReference(_, DataTypes.StringType, _, _) if a.name.startsWith("col") => + Cast(a, DataTypes.LongType) + case e => e + } + } + + def turnOffOperator(): Unit = { + session.sqlContext.setConf("org.apache.spark.sql.columnar.enabled", "false") + session.sqlContext.setConf("spark.sql.join.columnar.preferShuffledHashJoin", "false") + } + + def isLike(condition: Expression): Boolean = { + var result = false + condition.foreach { + case _: StartsWith => + result = true + case _ => + } + result + } +} + +class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectColumnar(session => NdpRules(session)) + extensions.injectOptimizerRule(session => NdpOptimizerRules(session)) + } +} + +object NdpPluginEnableFlag { + val ndpEnabledStr = "spark.omni.sql.ndpPlugin.enabled" + var ACCURATE_QUERY = "000" + + def isAccurate(condition: Expression): Boolean = { + var result = false + condition.foreach { + // literal need to check null + case literal: Literal if !literal.nullable && literal.value.toString.startsWith(ACCURATE_QUERY) => + result = true + case _ => + } + result + } + + val ndpOptimizedEnableStr = "spark.omni.sql.ndpPlugin.optimized.enabled" + + def isMatchedIpAddress: Boolean = { + val ipSet = Set("xxx.xxx.xxx.xxx") + val hostAddrSet = JavaConverters.asScalaSetConverter(NdpConnectorUtils.getIpAddress).asScala + val res = ipSet & hostAddrSet + res.nonEmpty + } + + def isEnable(session: SparkSession): Boolean = { + def ndpEnabled: Boolean = session.sqlContext.getConf( + ndpEnabledStr, "false").trim.toBoolean + ndpEnabled && (isMatchedIpAddress || NdpConnectorUtils.getNdpEnable) + } + + def isEnable: Boolean = { + def ndpEnabled: Boolean = sys.props.getOrElse( + ndpEnabledStr, "false").trim.toBoolean + ndpEnabled && (isMatchedIpAddress || NdpConnectorUtils.getNdpEnable) + } +def isNdpOptimizedEnable(session: SparkSession): Boolean = { + session.sqlContext.getConf(ndpOptimizedEnableStr, "true").trim.toBoolean + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/CountReplaceRule.scala 
b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/CountReplaceRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..85719d375be00a9e372ae016a342b79553cd7b82 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnioffload/spark/CountReplaceRule.scala @@ -0,0 +1,176 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.huawei.boostkit.omnioffload.spark + +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SimpleCountAggregateExec} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SimpleCountFileScanExec, SparkPlan} + +object CountReplaceRule extends Rule[SparkPlan] { + var columnStat: BigInt = -1 + var isCountPlan: Boolean = false + + override def apply(plan: SparkPlan): SparkPlan = { + if (shouldReplaceDistinctCount(plan) || shouldReplaceCountOne(plan)) { + replaceCountPlan(plan) + } else { + plan + } + } + + def shouldReplaceCountOne(plan: SparkPlan): Boolean = { + plan match { + case DataWritingCommandExec(_, + HashAggregateExec(_, groups: Seq[NamedExpression], aggExps: Seq[AggregateExpression], _, _, _, + ShuffleExchangeExec(_, + HashAggregateExec(_, _, _, _, _, _, + ColumnarToRowExec( + scan: FileSourceScanExec)), _))) => + if (groups.nonEmpty) { + return false + } + if (aggExps.isEmpty) { + return false + } + val headAggExp = aggExps.head + if (!headAggExp.aggregateFunction.isInstanceOf[Count]) { + return false + } + val countFunc = headAggExp.aggregateFunction.asInstanceOf[Count] + val countChild = countFunc.children + if (!countChild.equals(Seq(Literal(1)))) { + return false + } + if (!scan.relation.fileFormat.isInstanceOf[ParquetFileFormat]) { + return false + } + val countTable = scan.tableIdentifier.get + val stats = plan.sqlContext.sparkSession.sessionState.catalog + .getTableMetadata(countTable).stats + if (stats.isEmpty) { + return false + } + val countValue = stats.get.rowCount + if (countValue.isEmpty) { + return false + } + columnStat = countValue.get + isCountPlan = true + true + case _ => false + } + } + + def shouldReplaceDistinctCount(plan: SparkPlan): Boolean = { + plan match { + case DataWritingCommandExec(_, + 
topFinalAgg@HashAggregateExec(_, _, _, _, _, _, + ShuffleExchangeExec(_, + HashAggregateExec(_, _, _, _, _, _, + HashAggregateExec(_, _, _, _, _, _, + ShuffleExchangeExec(_, + HashAggregateExec(_, _, _, _, _, _, + ColumnarToRowExec( + scanExec: FileSourceScanExec)), _))), _))) => + if (topFinalAgg.groupingExpressions.nonEmpty) { + return false + } + val aggExps = topFinalAgg.aggregateExpressions + if (aggExps.size != 1) { + return false + } + val headAggExp = aggExps.head + if (!headAggExp.isDistinct) { + return false + } + if (!headAggExp.aggregateFunction.isInstanceOf[Count]) { + return false + } + val countFunc = headAggExp.aggregateFunction.asInstanceOf[Count] + val countChild = countFunc.children + if (countChild.size != 1) { + return false + } + if (!countChild.head.isInstanceOf[AttributeReference]) { + return false + } + val distinctColumn = scanExec.schema.head.name + val distinctTable = scanExec.tableIdentifier.get + + val stats = plan.sqlContext.sparkSession.sessionState.catalog + .getTableMetadata(distinctTable).stats + if (stats.isEmpty) { + return false + } + val colStatsMap = stats.map(_.colStats).getOrElse(Map.empty) + if (colStatsMap.isEmpty) { + return false + } + if (colStatsMap(distinctColumn) == null) { + return false + } + columnStat = colStatsMap(distinctColumn).distinctCount.get + true + case _ => false + } + } + + def replaceCountPlan(plan: SparkPlan): SparkPlan = plan match { + case scan: FileSourceScanExec if isCountPlan => + SimpleCountFileScanExec(scan.relation, + scan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan, + isEmptyIter = true) + case agg@HashAggregateExec(_, _, _, _, _, _, shuffle: ShuffleExchangeExec) if isCountPlan => + val child = replaceCountPlan(agg.child) + SimpleCountAggregateExec(agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + child, + isDistinctCount = true, + columnStat) + case agg: HashAggregateExec if !isCountPlan => + val child = replaceCountPlan(agg.child) + SimpleCountAggregateExec(agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + child, + isDistinctCount = true, + columnStat) + case p => + val children = plan.children.map(replaceCountPlan) + p.withNewChildren(children) + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogColumnStat.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogColumnStat.scala new file mode 100644 index 0000000000000000000000000000000000000000..abd0d9db515b2d0204a6071a72cfe2ad12303e94 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogColumnStat.scala @@ -0,0 +1,161 @@ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramSerializer} +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.types._ + +import java.time.ZoneOffset +import scala.util.control.NonFatal + +case class 
CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None, + version: Int = CatalogColumnStat.VERSION) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", CatalogColumnStat.VERSION.toString) + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType, version)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType, version)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram, + version = version) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + val VERSION = 2 + + private def getTimestampFormatter(isParsing: Boolean): TimestampFormatter = { + TimestampFormatter( + format = "yyyy-MM-dd HH:mm:ss.SSSSSS", + zoneId = ZoneOffset.UTC, + isParsing = isParsing) + } + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. 
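+ * Version 1 date and timestamp values are parsed via {@code java.sql} types; version 2 values are parsed with the UTC formatters.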
+ */ + def fromExternalString(s: String, name: String, dataType: DataType, version: Int): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType if version == 1 => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case DateType => DateFormatter(ZoneOffset.UTC).parse(s) + case TimestampType if version == 1 => + DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case TimestampType => getTimestampFormatter(isParsing = true).parse(s) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + case StringType => s + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateFormatter(ZoneOffset.UTC).format(v.asInstanceOf[Int]) + case TimestampType => getTimestampFormatter(isParsing = false).format(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType | StringType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. 
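+ * Returns `None` and logs a warning if the serialized statistics cannot be parsed.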
+ */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize), + version = map(s"${colName}.${KEY_VERSION}").toInt + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 91bcafbad5f9ba962f6397803c5159e273d585a9..9eddc2eb984d49325e23ebf9f0e9170a2b87f404 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution -import java.util.concurrent.TimeUnit._ +import com.huawei.boostkit.omnioffload.spark.NdpPluginEnableFlag +import java.util.concurrent.TimeUnit._ import scala.collection.mutable.HashMap - import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec @@ -31,10 +30,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.datasources.{FileScanRDDPushDown, _} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.ndp.NdpSupport +import org.apache.spark.sql.execution.ndp.NdpConf.{getNdpPartialPushdown, getNdpPartialPushdownEnable, getOptimizerPushDownEnable, getOptimizerPushDownPreThreadTask, getOptimizerPushDownThreshold, getTaskTimeout} +import org.apache.spark.sql.execution.ndp.{NdpConf, NdpSupport} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType @@ -42,6 +42,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet +import java.util.UUID + + + trait DataSourceScanExec extends LeafExecNode { def relation: BaseRelation def tableIdentifier: Option[TableIdentifier] @@ -160,7 +164,7 @@ case class RowDataSourceScanExec( * @param disableBucketedScan Disable bucketed scan based on physical query plan, see rule * [[DisableUnnecessaryBucketedScan]] for details. 
*/ -case class FileSourceScanExec( +abstract class BaseFileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], requiredSchema: StructType, @@ -169,11 +173,21 @@ case class FileSourceScanExec( optionalNumCoalescedBuckets: Option[Int], dataFilters: Seq[Expression], tableIdentifier: Option[TableIdentifier], - partiTionColumn: Seq[Attribute], - disableBucketedScan: Boolean = false + partitionColumn: Seq[Attribute], + disableBucketedScan: Boolean = false, + var runtimePartSum: Int = -1, + var runtimePushDownSum: Int = -1 ) extends DataSourceScanExec with NdpSupport { + def setRuntimePartSum(runtimePartSum: Int): Unit ={ + this.runtimePartSum = runtimePartSum + } + + def setRuntimePushDownSum(runtimePushDownSum: Int): Unit ={ + this.runtimePushDownSum = runtimePushDownSum + } + // Note that some vals referring the file-based relation are lazy intentionally // so that this plan can be canonicalized on executor side too. See SPARK-23731. override lazy val supportsColumnar: Boolean = { @@ -573,13 +587,8 @@ case class FileSourceScanExec( FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) } } - if (isPushDown) { - new FileScanRDDPushDown(fsRelation.sparkSession, filePartitions, requiredSchema, output, - relation.dataSchema, ndpOperators, partiTionColumn, supportsColumnar, fsRelation.fileFormat) - } else { - new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) - } + RDDPushDown(fsRelation, filePartitions, readFile) } /** @@ -620,13 +629,7 @@ case class FileSourceScanExec( val partitions = FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) - if (isPushDown) { - new FileScanRDDPushDown(fsRelation.sparkSession, partitions, requiredSchema, output, - relation.dataSchema, ndpOperators, partiTionColumn, supportsColumnar, fsRelation.fileFormat) - } else { - // TODO 重写一个FileScanRDD 重新调用 - new FileScanRDD(fsRelation.sparkSession, readFile, partitions) - } + RDDPushDown(fsRelation, partitions, readFile) } // Filters unused DynamicPruningExpression expressions - one which has been replaced @@ -655,8 +658,91 @@ case class FileSourceScanExec( optionalNumCoalescedBuckets, QueryPlan.normalizePredicates(dataFilters, filterOutput), None, - partiTionColumn.map(QueryPlan.normalizeExpressions(_, output)), + partitionColumn.map(QueryPlan.normalizeExpressions(_, output)), disableBucketedScan ) } + + private def RDDPushDown(fsRelation: HadoopFsRelation, filePartitions: Seq[FilePartition], readFile: (PartitionedFile) => Iterator[InternalRow]): RDD[InternalRow] = { + if (isPushDown) { + val partialCondition = allFilterExecInfo.nonEmpty && aggExeInfos.isEmpty && limitExeInfo.isEmpty && getNdpPartialPushdownEnable(fsRelation.sparkSession) + val partialPdRate = getNdpPartialPushdown(fsRelation.sparkSession) + var partialChildOutput = Seq[Attribute]() + if (partialCondition) { + partialChildOutput = allFilterExecInfo.head.child.output + logInfo(s"partial push down rate: ${partialPdRate}") + } + def isNdpPluginOptimizerPush: Boolean = getOptimizerPushDownEnable(fsRelation.sparkSession) && + NdpPluginEnableFlag.isEnable(fsRelation.sparkSession) + def taskTotal: Int = if(runtimePartSum > 0){ + runtimePartSum + } else { + filePartitions.size + } + val pushDownTotal: Int = if(runtimePushDownSum > 0) { + runtimePushDownSum + } else { + getOptimizerPushDownThreshold(fsRelation.sparkSession) + } + val preThreadTask: Int = getOptimizerPushDownPreThreadTask(fsRelation.sparkSession) + + val omniGroupId = String.valueOf(UUID.randomUUID); 
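// Editorial sketch (not part of the patch) of the precedence implemented just above: values
// injected at runtime through setRuntimePartSum / setRuntimePushDownSum take priority over
// the statically derived partition count and the configured push-down threshold.
def effectiveTotals(
    runtimePartSum: Int,
    runtimePushDownSum: Int,
    filePartitionCount: Int,
    configuredThreshold: Int): (Int, Int) = {
  val taskTotal = if (runtimePartSum > 0) runtimePartSum else filePartitionCount
  val pushDownTotal = if (runtimePushDownSum > 0) runtimePushDownSum else configuredThreshold
  (taskTotal, pushDownTotal)
}
// effectiveTotals(-1, -1, filePartitionCount = 200, configuredThreshold = 100) == (200, 100)
// effectiveTotals(50, 80, filePartitionCount = 200, configuredThreshold = 100) == (50, 80)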
+ + new FileScanRDDPushDown(fsRelation.sparkSession, filePartitions, requiredSchema, output, + relation.dataSchema, ndpOperators, partitionColumn, supportsColumnar, fsRelation.fileFormat, + readFile, partialCondition, partialPdRate, zkRate, partialChildOutput, isNdpPluginOptimizerPush, pushDownTotal, + taskTotal, preThreadTask, omniGroupId) + } else { + new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + } + } +} + +case class FileSourceScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + partitionColumn: Seq[Attribute], + disableBucketedScan: Boolean = false) + extends BaseFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + partitionColumn, + disableBucketedScan) { + } + +case class NdpFileSourceScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + partitionColumn: Seq[Attribute], + disableBucketedScan: Boolean = false) + extends BaseFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + partitionColumn, + disableBucketedScan) { + } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/RadixSortExec.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/RadixSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..be267fa5e91b10483d6a94dbe95ca3d81582c843 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/RadixSortExec.scala @@ -0,0 +1,196 @@ +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit._ +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf + +case class RadixSortExec( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryExecNode with BlockingOperatorWithCodegen { + + override def nodeName: String = "OmniRadixSort" + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + // sort performed is local within a given partition so will retain + // child operator's partitioning + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + private val enableRadixSort = 
    SQLConf.get.enableRadixSort
+
+  override lazy val metrics = Map(
+    "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"),
+    "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
+    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
+
+  private[sql] var rowSorter: UnsafeExternalRadixRowSorter = _
+
+  /**
+   * This method gets invoked only once for each SortExec instance to initialize an
+   * UnsafeExternalRowSorter, both `plan.execute` and code generation are using it.
+   * In the code generation code path, we need to call this function outside the class so we
+   * should make it public.
+   */
+  def createSorter(): UnsafeExternalRadixRowSorter = {
+    val ordering = RowOrdering.create(sortOrder, output)
+
+    // The comparator for comparing prefix
+    // Bind the sort-key attributes (AttributeReference -> BoundReference), keeping only the
+    // field ordinal, the data type and the nullability.
+
+    // TODO change 1: all expressions from sortOrder are taken here; limit this to two.
+    val boundSortExpressions = sortOrder.map(x => BindReferences.bindReference(x, output))
+    // Comparators are classified along these dimensions: signed/unsigned,
+    // NULLs first/NULLs last, ascending/descending.
+    val prefixComparators = boundSortExpressions.map(SortPrefixUtils.getPrefixComparator)
+
+    // The generator for prefix
+    // Generate a long-typed prefix value for each supported data type to compare on first.
+    val prefixExprs = boundSortExpressions.map(SortPrefix)
+    val prefixProjections = prefixExprs.map(x => UnsafeProjection.create(Seq(x)))
+    val prefixComputers = prefixProjections.zip(prefixExprs)
+      .map { case (prefixProjection, prefixExpr) =>
+        new UnsafeExternalRadixRowSorter.PrefixComputer {
+          private val result = new UnsafeExternalRadixRowSorter.PrefixComputer.Prefix
+
+          override def computePrefix(row: InternalRow):
+          UnsafeExternalRadixRowSorter.PrefixComputer.Prefix = {
+            val prefix = prefixProjection.apply(row)
+            result.isNull = prefix.isNullAt(0)
+            result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
+            result
+          }
+        }
+      }
+
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    rowSorter = UnsafeExternalRadixRowSorter.create(
+      schema, ordering,
+      scala.collection.JavaConverters.seqAsJavaList(prefixComparators),
+      scala.collection.JavaConverters.seqAsJavaList(prefixComputers),
+      pageSize, true)
+
+    if (testSpillFrequency > 0) {
+      rowSorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    rowSorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val peakMemory = longMetric("peakMemory")
+    val spillSize = longMetric("spillSize")
+    val sortTime = longMetric("sortTime")
+
+    child.execute().mapPartitionsInternal { iter =>
+      val sorter = createSorter()
+
+      val metrics = TaskContext.get().taskMetrics()
+      // Remember spill data size of this task before execute this operator so that we can
+      // figure out how many bytes we spilled for this operator.
+      val spillSizeBefore = metrics.memoryBytesSpilled
+      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+      sortTime += NANOSECONDS.toMillis(sorter.getSortTimeNanos)
+      peakMemory += sorter.getPeakMemoryUsage
+      spillSize += metrics.memoryBytesSpilled - spillSizeBefore
+      metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage)
+
+      sortedIterator
+    }
+  }
+
+  override def usedInputs: AttributeSet = AttributeSet(Seq.empty)
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  // Name of sorter variable used in codegen.
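// Editorial sketch (illustrative, not Spark API): the prefix idea wired up in createSorter
// above. Each row gets a single Long prefix derived from its leading sort key, so most
// comparisons are decided on that Long, and a radix sort can bucket rows by prefix bytes
// instead of invoking a full row comparator; ties fall back to the RowOrdering created above.
final case class PrefixedRow(prefix: Long, payload: String)

def radixBucket(row: PrefixedRow, bytePos: Int): Int =
  ((row.prefix >>> (bytePos * 8)) & 0xFF).toInt // byte of the prefix inspected in one pass

val rows = Seq(PrefixedRow(3L, "c"), PrefixedRow(1L, "a"), PrefixedRow(2L, "b"))
val byPrefix = rows.sortBy(_.prefix) // stand-in for the byte-wise radix passes
// byPrefix.map(_.payload) == Seq("a", "b", "c")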
+ private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + // Inline mutable state since not many Sort operations in a task + sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", + v => s"$v = $thisPlan.createSorter();", forceInline = true) + val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", + v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) + val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", + forceInline = true) + + val addToSorter = ctx.freshName("addToSorter") + val addToSorterFuncName = ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val peakMemory = metricTerm(ctx, "peakMemory") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + val sortTime = metricTerm(ctx, "sortTime") + s""" + | if ($needToSort) { + | long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $addToSorterFuncName(); + | $sortedIterator = $sorterVariable.sort(); + | $sortTime.add($sorterVariable.getSortTimeNanos() / $NANOS_PER_MILLIS); + | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($limitNotReachedCond $sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + s""" + |${row.code} + |$sorterVariable.insertRow((UnsafeRow)${row.value}); + """.stripMargin + } + + /** + * In SortExec, we overwrites cleanupResources to close UnsafeExternalRowSorter. + */ + override protected[sql] def cleanupResources(): Unit = { + if (rowSorter != null) { + // There's possible for rowSorter is null here, for example, in the scenario of empty + // iterator in the current task, the downstream physical node(like SortMergeJoinExec) will + // trigger cleanupResources before rowSorter initialized in createSorter. 
+ rowSorter.cleanupResources() + } + super.cleanupResources() + } + + protected def withNewChildInternal(newChild: SparkPlan): RadixSortExec = + copy(child = newChild) + + override def supportCodegen: Boolean = false +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SimpleCountFileScanExec.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SimpleCountFileScanExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..94de5fe69c32bf040e7dedcff1d2a2ab0e03e17e --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SimpleCountFileScanExec.scala @@ -0,0 +1,590 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.{SerializableConfiguration, Utils} + +import java.net.URI +import java.util.concurrent.TimeUnit._ +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.collection.mutable.HashMap + +//Only inputRDD and doExecute use are modified, other functions are the same as function in class FileSourceScanExec. 
+case class SimpleCountFileScanExec( + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + isEmptyIter: Boolean = false) + extends DataSourceScanExec { + + // Note that some vals referring the file-based relation are lazy intentionally + // so that this plan can be canonicalized on executor side too. See SPARK-23731. + override lazy val supportsColumnar: Boolean = { + relation.fileFormat.supportBatch(relation.sparkSession, schema) + } + + private lazy val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + sqlContext.conf.parquetVectorizedReaderEnabled + } else { + false + } + } + + override def vectorTypes: Option[Seq[String]] = + relation.fileFormat.vectorTypes( + requiredSchema = requiredSchema, + partitionSchema = relation.partitionSchema, + relation.sparkSession.sessionState.conf) + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has + * been initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles( + partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = NANOSECONDS.toMillis( + (System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). 
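// Editorial sketch (illustrative) of the runtime pruning described in the comment above and
// implemented by dynamicallySelectedPartitions below: a dynamic filter such as
// "store_id IN (<keys from the build side of a join>)" only becomes evaluable once the
// subquery result exists, so the statically selected partitions are filtered again at run time.
def pruneAtRuntime[P](staticPartitions: Seq[P], dynamicPredicate: Option[P => Boolean]): Seq[P] =
  dynamicPredicate match {
    case Some(pred) => staticPartitions.filter(pred)
    case None => staticPartitions
  }
// pruneAtRuntime(Seq("store=1", "store=2", "store=3"), Some((p: String) => p.endsWith("=2")))
//   == Seq("store=2")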
+ @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + /** + * [[partitionFilters]] can contain subqueries whose results are available only at runtime so + * accessing [[selectedPartitions]] should be guarded by this method during planning + */ + private def hasPartitionsAvailableAtRunTime: Boolean = { + partitionFilters.exists(ExecSubqueryExpression.hasSubquery) + } + + private def toAttribute(colName: String): Option[Attribute] = + output.find(_.name == colName) + + // exposed for testing + lazy val bucketedScan: Boolean = { + if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined + && !disableBucketedScan) { + val spec = relation.bucketSpec.get + val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) + bucketColumns.size == spec.bucketColumnNames.size + } else { + false + } + } + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + if (bucketedScan) { + // For bucketed columns: + // ----------------------- + // `HashPartitioning` would be used only when: + // 1. ALL the bucketing columns are being read from the table + // + // For sorted columns: + // --------------------- + // Sort ordering should be used when ALL these criteria's match: + // 1. `HashPartitioning` is being used + // 2. A prefix (or all) of the sort columns are being read from the table. + // + // Sort ordering would be over the prefix subset of `sort columns` being read + // from the table. + // e.g. + // Assume (col0, col2, col3) are the columns read from the table + // If sort columns are (col0, col1), then sort ordering would be considered as (col0) + // If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 + // above + val spec = relation.bucketSpec.get + val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) + val numPartitions = optionalNumCoalescedBuckets.getOrElse(spec.numBuckets) + val partitioning = HashPartitioning(bucketColumns, numPartitions) + val sortColumns = + spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) + val shouldCalculateSortOrder = + conf.getConf(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING) && + sortColumns.nonEmpty && + !hasPartitionsAvailableAtRunTime + + val sortOrder = if (shouldCalculateSortOrder) { + // In case of bucketing, its possible to have multiple files belonging to the + // same bucket in a given relation. Each of these files are locally sorted + // but those files combined together are not globally sorted. 
Given that, + // the RDD partition will not be sorted even if the relation has sort columns set + // Current solution is to check if all the buckets have a single file in it + + val files = selectedPartitions.flatMap(partition => partition.files) + val bucketToFilesGrouping = + files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file)) + val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1) + + // TODO SPARK-24528 Sort order is currently ignored if buckets are coalesced. + if (singleFilePartitions && optionalNumCoalescedBuckets.isEmpty) { + // TODO Currently Spark does not support writing columns sorting in descending order + // so using Ascending order. This can be fixed in future + sortColumns.map(attribute => SortOrder(attribute, Ascending)) + } else { + Nil + } + } else { + Nil + } + (partitioning, sortOrder) + } else { + (UnknownPartitioning(0), Nil) + } + } + + @transient + private lazy val pushedDownFilters = { + val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) + dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + } + + override lazy val metadata: Map[String, String] = { + def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") + + val location = relation.location + val locationDesc = + location.getClass.getSimpleName + + Utils.buildLocationMetadata(location.rootPaths, maxMetadataValueLength) + val metadata = + Map( + "Format" -> relation.fileFormat.toString, + "ReadSchema" -> requiredSchema.catalogString, + "Batched" -> supportsColumnar.toString, + "PartitionFilters" -> seqToString(partitionFilters), + "PushedFilters" -> seqToString(pushedDownFilters), + "DataFilters" -> seqToString(dataFilters), + "Location" -> locationDesc) + + // TODO(SPARK-32986): Add bucketed scan info in explain output of FileSourceScanExec + if (bucketedScan) { + relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + metadata + ("SelectedBucketsCount" -> + (s"$numSelectedBuckets out of ${spec.numBuckets}" + + optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)" }.getOrElse(""))) + } getOrElse { + metadata + } + } else { + metadata + } + } + + override def verboseStringWithOperatorId(): String = { + val metadataStr = metadata.toSeq.sorted.filterNot { + case (_, value) if (value.isEmpty || value.equals("[]")) => true + case (key, _) if (key.equals("DataFilters") || key.equals("Format")) => true + case (_, _) => false + }.map { + case (key, _) if (key.equals("Location")) => + val location = relation.location + val numPaths = location.rootPaths.length + val abbreviatedLocation = if (numPaths <= 1) { + location.rootPaths.mkString("[", ", ", "]") + } else { + "[" + location.rootPaths.head + s", ... 
${numPaths - 1} entries]" + } + s"$key: ${location.getClass.getSimpleName} ${redact(abbreviatedLocation)}" + case (key, value) => s"$key: ${redact(value)}" + } + + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", output)} + |${metadataStr.mkString("\n")} + |""".stripMargin + } + + lazy val inputRDD: RDD[InternalRow] = { + val readFile: (PartitionedFile) => Iterator[InternalRow] = if (isEmptyIter) { + emptyReadPartitionValues() + } else { + simpleReadPartitionValues() + } + val readRDD = if (bucketedScan) { + createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions, + relation) + } else { + createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + def emptyReadPartitionValues(): PartitionedFile => Iterator[InternalRow] = { + (file: PartitionedFile) => Iterator.empty + } + + def simpleReadPartitionValues(): PartitionedFile => Iterator[InternalRow] = { + val resultSchema = StructType(relation.dataSchema.fields ++ requiredSchema.fields) + val returningBatch = supportBatch(relation.sparkSession, resultSchema) + val broadcastedHadoopConf = + relation.sparkSession.sparkContext.broadcast( + new SerializableConfiguration( + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))) + (file: PartitionedFile) => { + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + val sharedConf = broadcastedHadoopConf.value.value + + val footer = ParquetFileReader.readFooter(sharedConf, filePath, + ParquetMetadataConverter.range(split.getStart, split.getEnd)) + val count = footer.getBlocks.map(_.getRowCount).sum + val batch = createBatch(count) + val iter = new Iterator[Object] { + var batchId = 0 + + override def hasNext: Boolean = { + if (batchId >= 1) { + return false + } + true + } + + override def next(): Object = { + batchId = batchId + 1 + if (returningBatch) return batch + batch.getRow(0) + } + } + iter.asInstanceOf[Iterator[InternalRow]] + } + } + + def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + + val virtualField = StructField("virtualField", LongType, false, Metadata.empty) + + def createBatch(count: Long): ColumnarBatch = { + var batchSchema = new StructType() + batchSchema = batchSchema.add(virtualField) + val columnVectors = OnHeapColumnVector.allocateColumns(1, batchSchema) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) + for (vector <- columnVectors) { + vector.reset() + vector.putLong(0, count) + } + columnarBatch.setNumRows(1) + columnarBatch + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** SQL metrics generated only for scans using dynamic partition pruning. */ + private lazy val staticMetrics = if (partitionFilters.exists(isDynamicPruningFilter)) { + Map("staticFilesNum" -> SQLMetrics.createMetric(sparkContext, "static number of files read"), + "staticFilesSize" -> SQLMetrics.createSizeMetric(sparkContext, "static size of files read")) + } else { + Map.empty[String, SQLMetric] + } + + /** Helper for computing total number and size of files in selected partitions. 
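// Editorial sketch (illustrative) of the footer-only row count that simpleReadPartitionValues
// above relies on: only Parquet metadata is read and no column data is decoded. Assumes a
// Hadoop Configuration and the path of an existing Parquet file; imports beyond those already
// present in this file are listed explicitly.
import org.apache.hadoop.conf.Configuration
import scala.collection.JavaConverters._

def parquetRowCount(conf: Configuration, file: Path): Long = {
  val footer = ParquetFileReader.readFooter(conf, file, ParquetMetadataConverter.NO_FILTER)
  footer.getBlocks.asScala.map(_.getRowCount).sum
}
// A count(*) over one file can then be answered as parquetRowCount(conf, path) with no scan.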
*/ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchemaOption.isDefined) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), + "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), + "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read") + ) ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. + if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + if (relation.partitionSchemaOption.isDefined) { + Map( + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), + "pruningTime" -> + SQLMetrics.createTimingMetric(sparkContext, "dynamic partition pruning time")) + } else { + Map.empty[String, SQLMetric] + } + } ++ staticMetrics + + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + var batchSchema = new StructType() + batchSchema = batchSchema.add(virtualField) + val toUnsafe = UnsafeProjection.create(batchSchema) + toUnsafe.initialize(index) + iter.map { row => + numOutputRows += 1 + toUnsafe(row) + } + } + } else { + inputRDD.mapPartitionsInternal { iter => + iter.map { row => + numOutputRows += 1 + row + } + } + } + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override val nodeNamePrefix: String = "File" + + /** + * Create an RDD for bucketed reads. + * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the files + * with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec the bucketing spec. + * @param readFile a function to read each (part of a) file. + * @param selectedPartitions Hive-style partition that are part of the read. + * @param fsRelation [[HadoopFsRelation]] associated with the read. 
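// Editorial sketch (illustrative) of the scanTime accounting in doExecuteColumnar above: the
// underlying scan does its work lazily inside hasNext, so that is where the timer has to sit.
// `addMillis` stands in for the SQLMetric update.
def timed[T](underlying: Iterator[T])(addMillis: Long => Unit): Iterator[T] = new Iterator[T] {
  override def hasNext: Boolean = {
    val startNs = System.nanoTime()
    val result = underlying.hasNext
    addMillis((System.nanoTime() - startNs) / 1000000L)
    result
  }
  override def next(): T = underlying.next()
}
// Usage: timed(batches)(scanTime += _) wraps an iterator of batches without changing what it
// yields.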
+ */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions.flatMap { p => + p.files.map { f => + PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values) + } + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + // TODO(SPARK-32985): Decouple bucket filter pruning and bucketed table scan + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets.map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets.get(bucketId).map { + _.values.flatten.toArray + }.getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + }.getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. + * The bucketed variant of this function is [[createBucketedReadRDD]]. + * + * @param readFile a function to read each (part of a) file. + * @param selectedPartitions Hive-style partition that are part of the read. + * @param fsRelation [[HadoopFsRelation]] associated with the read. 
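// Editorial sketch (illustrative) of the bucket-coalescing arithmetic in createBucketedReadRDD
// above: original bucket ids are folded into numCoalescedBuckets groups by modulo, so when 12
// buckets are coalesced down to 4, buckets 0, 4 and 8 all land in coalesced bucket 0 and their
// files end up in a single FilePartition.
val numCoalescedBuckets = 4
val coalesced = (0 until 12).groupBy(_ % numCoalescedBuckets)
// coalesced(0) contains 0, 4, 8; coalesced(1) contains 1, 5, 9; and so on.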
+ */ + private def createNonBucketedReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + } + }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) + + new FileScanRDD(fsRelation.sparkSession, readFile, partitions) + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def filterUnusedDynamicPruningExpressions(predicates: + Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): SimpleCountFileScanExec = { + SimpleCountFileScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan) + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ba17441fac2a7a7b3bd782705cb2a222af22855a..82265e2da916fc63f357cc00f5b16d101229f12a 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, AnalysisException, Strategy} +import org.apache.spark.sql.{AnalysisException, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec} +import org.apache.spark.sql.execution.ndp.NdpFilterEstimation import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import 
org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -537,7 +538,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => val condition = filters.reduceLeftOption(And) val selectivity = if (condition.nonEmpty) { - FilterEstimation(Filter(condition.get, mem)).calculateFilterSelectivity(condition.get) + NdpFilterEstimation(FilterEstimation(Filter(condition.get, mem))).calculateFilterSelectivity(condition.get) } else { None } @@ -605,8 +606,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil - case r: RunnableCommand => ExecutedCommandExec(r) :: Nil - + case r: RunnableCommand => + r match { + case cmd: AnalyzeColumnCommand if conf.getConfString("spark.sql.ndp.string.analyze.enabled","true").toBoolean => + ExecutedCommandExec(NdpAnalyzeColumnCommand(cmd.tableIdent, cmd.columnNames, cmd.allColumns)) :: Nil + case _ => + ExecutedCommandExec(r) :: Nil + } case MemoryPlan(sink, output) => val encoder = RowEncoder(StructType.fromAttributes(output)) val toRow = encoder.createSerializer() @@ -687,12 +693,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Project(projectList, child) => execution.ProjectExec(projectList, planLater(child)) :: Nil case l @ logical.Filter(condition, child) => - val selectivity = FilterEstimation(l).calculateFilterSelectivity(l.condition) + val selectivity = NdpFilterEstimation(FilterEstimation(l)).calculateFilterSelectivity(l.condition) execution.FilterExec(condition, planLater(child), selectivity) :: Nil case f: logical.TypedFilter => val condition = f.typedCondition(f.deserializer) val filter = Filter(condition, f.child) - val selectivity = FilterEstimation(filter).calculateFilterSelectivity(condition) + val selectivity = NdpFilterEstimation(FilterEstimation(filter)).calculateFilterSelectivity(condition) execution.FilterExec(condition, planLater(f.child), selectivity) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountAggregateExec.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountAggregateExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..45d942a0be252076651ea6bb35ded3101e8950d3 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountAggregateExec.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator + +import java.util.concurrent.TimeUnit._ + +/** + * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. + */ +case class SimpleCountAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan, + isDistinctCount: Boolean = false, + var columnStat: BigInt = -1) + extends BaseAggregateExec + with BlockingOperatorWithCodegen { + + require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) + + override def nodeName: String = "SimpleCountAggregate" + + override lazy val allAttributes: AttributeSeq = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"), + "avgHashProbe" -> + SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters")) + + // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash + // map and/or the sort-based aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[(Int, Int)] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => + val splits = fallbackStartsAt.split(",").map(_.trim) + Some((splits.head.toInt, splits.last.toInt)) + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val peakMemory = longMetric("peakMemory") + val spillSize = longMetric("spillSize") + val avgHashProbe = longMetric("avgHashProbe") + val aggTime = longMetric("aggTime") + + child.execute().mapPartitionsWithIndex { (partIndex, iter) => + + val beforeAgg = System.nanoTime() + val hasInput = iter.hasNext + val res = if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator.empty + } else { + val aggregationIterator = + new SimpleCountTungstenAggIter( + partIndex, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + MutableProjection.create(expressions, inputSchema), + inputAttributes, + iter, + testFallbackStartsAt, + numOutputRows, + peakMemory, + spillSize, + avgHashProbe, + isDistinctCount, + columnStat) + if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + aggTime += NANOSECONDS.toMillis(System.nanoTime() - beforeAgg) + res + } + } + + private val modes = aggregateExpressions.map(_.mode).distinct + + override def usedInputs: AttributeSet = inputSet + + override def supportCodegen: Boolean = false + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes + ) + } + + def getTaskMemoryManager(): TaskMemoryManager = { + TaskContext.get().taskMemoryManager() + } + + def getEmptyAggregationBuffer(): InternalRow = { + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + initialBuffer + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + /** + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter, + peakMemory: SQLMetric, + spillSize: SQLMetric, + avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.add(maxMemory) + metrics.incPeakExecutionMemory(maxMemory) + + // Update average hashmap probe + avgHashProbe.set(hashMap.getAvgHashProbeBucketListIterations) + + if (sorter == null) { + // not spilled + return hashMap.iterator() + } + + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. 
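// Editorial sketch (illustrative) of the merge loop in the KVIterator below: after spilling,
// rows arrive sorted by grouping key, so all rows of a group are adjacent and can be folded
// into one buffer before the next key starts. Stand-in types: (key, partial count) pairs.
def mergeSortedRuns(sorted: Seq[(String, Long)]): Seq[(String, Long)] =
  sorted.foldLeft(Vector.empty[(String, Long)]) { (acc, kv) =>
    acc.lastOption match {
      case Some((k, c)) if k == kv._1 => acc.init :+ (k -> (c + kv._2))
      case _ => acc :+ kv
    }
  }
// mergeSortedRuns(Seq("a" -> 1L, "a" -> 2L, "b" -> 5L)) == Vector("a" -> 3L, "b" -> 5L)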
+ new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = MutableProjection.create( + mergeExpr, + aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes)) + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } + + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + spillSize.add(sorter.getSpillSize) + false + } + } + + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) + + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) + + private def toString(verbose: Boolean, maxFields: Int): String = { + val allAggregateExpressions = aggregateExpressions + + testFallbackStartsAt match { + case None => + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) + if (verbose) { + s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" + } else { + s"HashAggregate(keys=$keyString, functions=$functionString)" + } + case Some(fallbackStartsAt) => + s"HashAggregateWithControlledFallback $groupingExpressions " + + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" + } + } + + override protected def doProduce(ctx: CodegenContext): String = "" +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountTungstenAggIter.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountTungstenAggIter.scala new file mode 100644 index 0000000000000000000000000000000000000000..9f4fbffb612bf7e0bba0bebeb0ea334d8e8c5276 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/aggregate/SimpleCountTungstenAggIter.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. If + * this map cannot allocate memory from memory manager, it spills the map into disk + * and creates a new one. After processed all the input, then merge all the spills + * together using external sorter, and do sort-based aggregation. + * + * The process has the following step: + * - Step 0: Do hash-based aggregation. + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create an external sorter based on the spilled sorted map entries and reset the map. + * - Step 3: Get a sorted [[KVIterator]] from the external sorter. + * - Step 4: Repeat step 0 until no more input. + * - Step 5: Initialize sort-based aggregation on the sorted iterator. + * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: Methods and fields used when we switch to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and process input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param partIndex + * index of the partition + * @param groupingExpressions + * expressions for grouping keys + * @param aggregateExpressions + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param aggregateAttributes the attributes of the aggregateExpressions' + * outputs when they are stored in the final aggregation buffer. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. 
+ */ +class SimpleCountTungstenAggIter( + partIndex: Int, + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], + testFallbackStartsAt: Option[(Int, Int)], + numOutputRows: SQLMetric, + peakMemory: SQLMetric, + spillSize: SQLMetric, + avgHashProbe: SQLMetric, + shouldDistinctCount: Boolean = false, + DistinctCountValue: BigInt = -1) + extends AggregationIterator( + partIndex, + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. + private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // Creates a new aggregation buffer and initializes buffer values. + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) + val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) + .apply(new GenericInternalRow(bufferSchema.length)) + // Initialize declarative aggregates' buffer values + expressionAggInitialProjection.target(buffer)(EmptyRow) + // Initialize imperative aggregates' buffer values + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + buffer + } + + // Creates a function used to generate output rows. + override protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() + } + } + + // An aggregation buffer containing initial buffer values. It is used to + // initialize other aggregation buffers. 
+ private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If there is not enough memory, it will multiple hash-maps, spilling + // after each becomes full then using sort to merge these spills, finally do sort + // based aggregation. + private def processInputs(fallbackStartsAt: (Int, Int)): Unit = { + if (groupingExpressions.isEmpty) { + // If there is no grouping expressions, we can just reuse the same buffer over and over again. + // Note that it would be better to eliminate the hash map entirely in the future. + val groupingKey = groupingProjection.apply(null) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (shouldDistinctCount) { + buffer.setLong(0, DistinctCountValue.toLong) + } else { + while (inputIter.hasNext) { + val newInput = inputIter.next() + buffer.setLong(0, buffer.getLong(0) + newInput.getLong(1)) + } + } + } else { + var i = 0 + while (inputIter.hasNext) { + val newInput = inputIter.next() + i += 1 + } + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: Methods and fields used when we switch to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This sorter is used for sort-based aggregation. It is initialized as soon as + // we switch from hash-based to sort-based aggregation. Otherwise, it is not used. + private[this] var externalSorter: UnsafeKVExternalSorter = null + + /** + * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. + */ + private def switchToSortBasedAggregation(): Unit = { + logInfo("falling back to sort based aggregation.") + + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
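// --- Editorial sketch (not part of this patch) ---------------------------------
// A minimal, self-contained illustration of the two count paths taken by
// processInputs above when there is no grouping key: either the partial counts
// carried by the input rows are summed into one reused buffer slot, or a
// pre-computed distinct count (DistinctCountValue) is written into it directly.
// All names below are illustrative stand-ins, not Spark internals.
object CountBufferSketch {
  // A single mutable cell standing in for slot 0 of the reused UnsafeRow buffer.
  final class CountBuffer { var count: Long = 0L }

  // Non-distinct path: accumulate partial counts coming from upstream rows.
  def mergePartialCounts(buffer: CountBuffer, partialCounts: Iterator[Long]): CountBuffer = {
    partialCounts.foreach(c => buffer.count += c)
    buffer
  }

  // shouldDistinctCount path: inject an externally computed distinct count.
  def injectDistinctCount(buffer: CountBuffer, distinctCount: Long): CountBuffer = {
    buffer.count = distinctCount
    buffer
  }

  def main(args: Array[String]): Unit = {
    println(mergePartialCounts(new CountBuffer, Iterator(3L, 5L, 2L)).count) // 10
    println(injectDistinctCount(new CountBuffer, 42L).count)                 // 42
  }
}
// --------------------------------------------------------------------------------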
+ val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _, _, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _, _, _) => + agg.copy(mode = Final) + case other => other + } + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). + if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. + private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we has new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (InternalRow, InternalRow) => Unit = null + + // Processes rows in the current group. It will stop when it find a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair to make the condition of the while loop + // has no action (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). 
+ val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs the current input row. + if (currentGroupingKey.equals(groupingKey)) { + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group. + findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + /** + * Start processing input rows. + */ + processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue))) + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + TaskContext.get().addTaskCompletionListener[Unit](_ => { + // At the end of the task, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.set(maxMemory) + spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) + metrics.incPeakExecutionMemory(maxMemory) + + // Updating average hashmap probe + avgHashProbe.set(hashMap.getAvgHashProbeBucketListIterations) + }) + + /////////////////////////////////////////////////////////////////////////// + // Part 7: Iterator's public methods. + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + val res = if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. 
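// --- Editorial sketch (not part of this patch) ---------------------------------
// A minimal sketch of the sort-based merge performed by processCurrentSortedGroup
// above: walk a key-sorted stream, fold together consecutive values with the same
// key, and keep the next pair pre-loaded so that hasNext stays idempotent. The
// generic merge function stands in for sortBasedProcessRow; no Spark types are used.
object SortedGroupMergeSketch {
  def mergeSorted[K, V](sorted: Iterator[(K, V)])(merge: (V, V) => V): Iterator[(K, V)] =
    new Iterator[(K, V)] {
      // Pre-loaded next pair, mirroring nextGroupingKey / firstRowInNextGroup.
      private var lookahead: Option[(K, V)] = if (sorted.hasNext) Some(sorted.next()) else None

      override def hasNext: Boolean = lookahead.isDefined

      override def next(): (K, V) = {
        val (key, first) = lookahead.getOrElse(throw new NoSuchElementException)
        lookahead = None
        var acc = first
        var foundNextGroup = false
        while (!foundNextGroup && sorted.hasNext) {
          val (k, v) = sorted.next()
          if (k == key) acc = merge(acc, v)                        // same group: keep merging
          else { lookahead = Some((k, v)); foundNextGroup = true } // new group: stash it
        }
        (key, acc)
      }
    }

  def main(args: Array[String]): Unit = {
    val input = Iterator(("a", 1), ("a", 2), ("b", 5), ("c", 1), ("c", 1))
    println(mergeSorted(input)(_ + _).toList) // List((a,3), (b,5), (c,2))
  }
}
// --------------------------------------------------------------------------------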
+ mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + + numOutputRows += 1 + res + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: Utility functions + /////////////////////////////////////////////////////////////////////////// + + /** + * Generate an output row when there is no input and there is no grouping expression. + */ + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create an output row and copy it. So, we can free the map. + val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpAnalyzeColumnCommand.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpAnalyzeColumnCommand.scala new file mode 100644 index 0000000000000000000000000000000000000000..20311b0cb909edc75d607eee51397b41e9ddf1f1 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpAnalyzeColumnCommand.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import java.time.ZoneOffset + + +/** + * Analyzes the given columns of the given table to generate statistics, which will be used in + * query optimizations. Parameter `allColumns` may be specified to generate statistics of all the + * columns of a given table. 
+ */ +case class NdpAnalyzeColumnCommand( + tableIdent: TableIdentifier, + columnNames: Option[Seq[String]], + allColumns: Boolean) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " + + "mutually exclusive. Only one of them should be specified.") + val sessionState = sparkSession.sessionState + + tableIdent.database match { + case Some(db) if db == sparkSession.sharedState.globalTempViewManager.database => + val plan = sessionState.catalog.getGlobalTempView(tableIdent.identifier).getOrElse { + throw new NoSuchTableException(db = db, table = tableIdent.identifier) + } + analyzeColumnInTempView(plan, sparkSession) + case Some(_) => + analyzeColumnInCatalog(sparkSession) + case None => + sessionState.catalog.getTempView(tableIdent.identifier) match { + case Some(tempView) => analyzeColumnInTempView(tempView, sparkSession) + case _ => analyzeColumnInCatalog(sparkSession) + } + } + + Seq.empty[Row] + } + + private def analyzeColumnInCachedData(plan: LogicalPlan, sparkSession: SparkSession): Boolean = { + val cacheManager = sparkSession.sharedState.cacheManager + val planToLookup = sparkSession.sessionState.executePlan(plan).analyzed + cacheManager.lookupCachedData(planToLookup).map { cachedData => + val columnsToAnalyze = getColumnsToAnalyze( + tableIdent, cachedData.cachedRepresentation, columnNames, allColumns) + cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze) + cachedData + }.isDefined + } + + private def analyzeColumnInTempView(plan: LogicalPlan, sparkSession: SparkSession): Unit = { + if (!analyzeColumnInCachedData(plan, sparkSession)) { + throw new AnalysisException( + s"Temporary view $tableIdent is not cached for analyzing columns.") + } + } + + private def getColumnsToAnalyze( + tableIdent: TableIdentifier, + relation: LogicalPlan, + columnNames: Option[Seq[String]], + allColumns: Boolean = false): Seq[Attribute] = { + val columnsToAnalyze = if (allColumns) { + relation.output + } else { + columnNames.get.map { col => + val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) + exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) + } + } + // Make sure the column types are supported for stats gathering. + columnsToAnalyze.foreach { attr => + if (!supportsType(attr.dataType)) { + throw new AnalysisException( + s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + + "and Spark does not support statistics collection on this column type.") + } + } + columnsToAnalyze + } + + private def analyzeColumnInCatalog(sparkSession: SparkSession): Unit = { + val sessionState = sparkSession.sessionState + val tableMeta = sessionState.catalog.getTableMetadata(tableIdent) + if (tableMeta.tableType == CatalogTableType.VIEW) { + // Analyzes a catalog view if the view is cached + val plan = sparkSession.table(tableIdent.quotedString).logicalPlan + if (!analyzeColumnInCachedData(plan, sparkSession)) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") + } + } else { + val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta) + val relation = sparkSession.table(tableIdent).logicalPlan + val columnsToAnalyze = getColumnsToAnalyze(tableIdent, relation, columnNames, allColumns) + + SQLConf.get.setConfString("spark.omni.sql.ndpPlugin.castDecimal.enabled", "false") + // Compute stats for the computed list of columns. 
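// --- Editorial sketch (not part of this patch) ---------------------------------
// A plain-Scala illustration of the per-column statistics that computeColumnStats
// (invoked just below) gathers in a single pass: the overall row count plus, for
// each column, the distinct count, min, max and null count. The real implementation
// evaluates these as aggregate expressions in one Spark job; the types here are
// illustrative stand-ins for a subset of ColumnStat's fields.
object BasicColumnStatsSketch {
  final case class BasicStats(distinctCount: Long, min: Option[Int], max: Option[Int], nullCount: Long)

  def computeStats(column: Seq[Option[Int]]): (Long, BasicStats) = {
    val rowCount = column.size.toLong
    val nonNull  = column.flatten
    val stats = BasicStats(
      distinctCount = nonNull.distinct.size.toLong,
      min           = if (nonNull.isEmpty) None else Some(nonNull.min),
      max           = if (nonNull.isEmpty) None else Some(nonNull.max),
      nullCount     = rowCount - nonNull.size)
    (rowCount, stats)
  }

  def main(args: Array[String]): Unit = {
    println(computeStats(Seq(Some(3), None, Some(7), Some(3), None)))
    // (5,BasicStats(2,Some(3),Some(7),2))
  }
}
// --------------------------------------------------------------------------------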
+ val (rowCount, newColStats) = + NdpCommandUtils.computeColumnStats(sparkSession, relation, columnsToAnalyze) + SQLConf.get.setConfString("spark.omni.sql.ndpPlugin.castDecimal.enabled", "true") + val newColCatalogStats = newColStats.map { + case (attr, columnStat) => + attr.name -> toCatalogColumnStat(columnStat, attr.name, attr.dataType) + } + + // We also update table-level stats in order to keep them consistent with column-level stats. + val statistics = CatalogStatistics( + sizeInBytes = sizeInBytes, + rowCount = Some(rowCount), + // Newly computed column stats should override the existing ones. + colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColCatalogStats) + + sessionState.catalog.alterTableStats(tableIdent, Some(statistics)) + } + } + + private def toCatalogColumnStat(columnStat: ColumnStat, colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = columnStat.distinctCount, + min = columnStat.min.map(toExternalString(_, colName, dataType)), + max = columnStat.max.map(toExternalString(_, colName, dataType)), + nullCount = columnStat.nullCount, + avgLen = columnStat.avgLen, + maxLen = columnStat.maxLen, + histogram = columnStat.histogram, + version = columnStat.version) + + private def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateFormatter(ZoneOffset.UTC).format(v.asInstanceOf[Int]) + case TimestampType => getTimestampFormatter(isParsing = false).format(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType | StringType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + private def getTimestampFormatter(isParsing: Boolean): TimestampFormatter = { + TimestampFormatter( + format = "yyyy-MM-dd HH:mm:ss.SSSSSS", + zoneId = ZoneOffset.UTC, + isParsing = isParsing) + } + + /** Returns true iff the we support gathering column statistics on column of the given type. 
*/ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpCommandUtils.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpCommandUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..1d63ab4dac56a72fcc0aab81a5291998afe496c5 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/command/NdpCommandUtils.scala @@ -0,0 +1,203 @@ +package org.apache.spark.sql.execution.command + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.functions.countDistinct +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +object NdpCommandUtils extends Logging { + + private[sql] def computeColumnStats( + sparkSession: SparkSession, + relation: LogicalPlan, + columns: Seq[Attribute]): (Long, Map[Attribute, ColumnStat]) = { + val conf = sparkSession.sessionState.conf + + // Collect statistics per column. + // If no histogram is required, we run a job to compute basic column stats such as + // min, max, ndv, etc. Otherwise, besides basic column stats, histogram will also be + // generated. Currently we only support equi-height histogram. + // To generate an equi-height histogram, we need two jobs: + // 1. compute percentiles p(0), p(1/n) ... p((n-1)/n), p(1). + // 2. use the percentiles as value intervals of bins, e.g. [p(0), p(1/n)], + // [p(1/n), p(2/n)], ..., [p((n-1)/n), p(1)], and then count ndv in each bin. + // Basic column stats will be computed together in the second job. + val attributePercentiles = computePercentiles(columns, sparkSession, relation) + + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the ColumnStats. + val expressions = Count(Literal(1)).toAggregateExpression() +: + columns.map(statExprs(_, conf, attributePercentiles)) + + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .executedPlan.executeTake(1).head + + val rowCount = statsRow.getLong(0) + val columnStats = columns.zipWithIndex.map { case (attr, i) => + // according to `statExprs`, the stats struct always have 7 fields. + (attr, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr))) + }.toMap + (rowCount, columnStats) + } + + /** Computes percentiles for each attribute. 
*/ + private def computePercentiles( + attributesToAnalyze: Seq[Attribute], + sparkSession: SparkSession, + relation: LogicalPlan): AttributeMap[ArrayData] = { + val conf = sparkSession.sessionState.conf + val attrsToGenHistogram = if (conf.histogramEnabled) { + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) + } else { + Nil + } + val attributePercentiles = mutable.HashMap[Attribute, ArrayData]() + if (attrsToGenHistogram.nonEmpty) { + val percentiles = (0 to conf.histogramNumBins) + .map(i => i.toDouble / conf.histogramNumBins).toArray + + val namedExprs = attrsToGenHistogram.map { attr => + val aggFunc = + new ApproximatePercentile(attr, + Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)), + Literal(conf.percentileAccuracy)) + val expr = aggFunc.toAggregateExpression() + Alias(expr, expr.toString)() + } + + val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) + .executedPlan.executeTake(1).head + attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => + val percentiles = percentilesRow.getArray(i) + // When there is no non-null value, `percentiles` is null. In such case, there is no + // need to generate histogram. + if (percentiles != null) { + attributePercentiles += attr -> percentiles + } + } + } + AttributeMap(attributePercentiles.toSeq) + } + + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + + val one = Literal(1.toLong, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = countDistinct(col.name).expr + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. 
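// --- Editorial sketch (not part of this patch) ---------------------------------
// A minimal sketch of how the equi-height histogram is assembled later in this file
// (rowToColumnStat): the percentiles produced by computePercentiles become the bin
// endpoints, each bin carries the distinct-value count measured for its interval,
// and the bin height is nonNullRows / numBins. Bin and EquiHeightHistogram are
// illustrative stand-ins for Spark's HistogramBin and Histogram.
object EquiHeightHistogramSketch {
  final case class Bin(lo: Double, hi: Double, ndv: Long)
  final case class EquiHeightHistogram(height: Double, bins: Seq[Bin])

  // endpoints holds numBins + 1 percentile values p(0), p(1/n), ..., p(1);
  // ndvPerBin holds one distinct-value count per interval.
  def build(endpoints: Array[Double], ndvPerBin: Array[Long], nonNullRows: Long): EquiHeightHistogram = {
    require(endpoints.length == ndvPerBin.length + 1, "need one more endpoint than bins")
    val bins = ndvPerBin.zipWithIndex.map { case (ndv, i) => Bin(endpoints(i), endpoints(i + 1), ndv) }
    EquiHeightHistogram(nonNullRows.toDouble / ndvPerBin.length, bins.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // 4 bins over values 0..100 and 1000 non-null rows: each bin represents 250 rows.
    val h = build(Array(0.0, 25.0, 50.0, 75.0, 100.0), Array(20L, 25L, 18L, 22L), 1000L)
    println(h.height)    // 250.0
    println(h.bins.head) // Bin(0.0,25.0,20)
  }
}
// --------------------------------------------------------------------------------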
+ struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case StringType => fixedLenTypeStruct + case BinaryType => + // For binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** + * Convert a struct for column stats (defined in `statExprs`) into + * [[org.apache.spark.sql.catalyst.plans.logical.ColumnStat]]. + */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index ccbe1c5558de81ad769cda68403ec1bbab9f3406..4869c6f466ff7f19219ec8d4e975d8568c73cd96 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale - import scala.collection.mutable - import org.apache.hadoop.fs.Path - import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -34,7 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter => LFilter, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} 
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, Filter => LFilter} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 @@ -42,6 +39,7 @@ import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.ndp.NdpFilterEstimation import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.sources._ @@ -422,7 +420,7 @@ object DataSourceStrategy relation.relation, relation.catalogTable.map(_.identifier)) filterCondition.map{ x => - val selectivity = FilterEstimation(LFilter(x, relation)).calculateFilterSelectivity(x) + val selectivity = NdpFilterEstimation(FilterEstimation(LFilter(x, relation))).calculateFilterSelectivity(x) execution.FilterExec(x, scan, selectivity) }.getOrElse(scan) } else { @@ -448,7 +446,7 @@ object DataSourceStrategy relation.catalogTable.map(_.identifier)) execution.ProjectExec( projects, filterCondition.map{x => - val selectivity = FilterEstimation(LFilter(x, relation)).calculateFilterSelectivity(x) + val selectivity = NdpFilterEstimation(FilterEstimation(LFilter(x, relation))).calculateFilterSelectivity(x) execution.FilterExec(x, scan, selectivity) }.getOrElse(scan)) } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala index c55ed09a0b0c25588ac5bd1e2878ba2b5c9b2709..fadb8eae150bf6da38421be1a19a12b9aeea6cc4 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala @@ -17,36 +17,57 @@ package org.apache.spark.sql.execution.datasources -import java.util +import com.google.common.collect.ImmutableMap +import com.huawei.boostkit.omnidata.exception.OmniDataException + +import java.util import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.parquet.io.ParquetDecodingException import org.apache.spark.{SparkUpgradeException, TaskContext, Partition => RDDPartition} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.{DataIoAdapter, NdpUtils, PageCandidate, PageToColumnar, PushDownManager, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.execution.ndp.{NdpConf, PushDownInfo} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BasePredicate, Expression, Predicate, UnsafeProjection} +import org.apache.spark.sql.execution.ndp.NdpSupport.filterStripEnd +import org.apache.spark.sql.execution.{QueryExecutionException, RowToColumnConverter} +import 
org.apache.spark.sql.execution.ndp.{FilterExeInfo, NdpConf, PushDownInfo} +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.NextIterator + +import java.io.{FileNotFoundException, IOException} +import scala.util.Random /** * An RDD that scans a list of file partitions. */ class FileScanRDDPushDown( - @transient private val sparkSession: SparkSession, - @transient val filePartitions: Seq[FilePartition], - requiredSchema: StructType, - output: Seq[Attribute], - dataSchema: StructType, - pushDownOperators: PushDownInfo, - partitionColumns: Seq[Attribute], - isColumnVector: Boolean, - fileFormat: FileFormat) + @transient private val sparkSession: SparkSession, + @transient val filePartitions: Seq[FilePartition], + requiredSchema: StructType, + output: Seq[Attribute], + dataSchema: StructType, + pushDownOperators: PushDownInfo, + partitionColumns: Seq[Attribute], + isColumnVector: Boolean, + fileFormat: FileFormat, + readFunction: (PartitionedFile) => Iterator[InternalRow], + partialCondition: Boolean, + partialPdRate: Double, + zkPdRate: Double, + partialChildOutput: Seq[Attribute], + isOptimizerPushDown: Boolean = false, + pushDownTotal: Int, + taskTotal: Int, + perThreadTask: Int = 1, + omniGroupId: String) extends RDD[InternalRow](sparkSession.sparkContext, Nil) { var columnOffset = -1 @@ -63,7 +84,7 @@ class FileScanRDDPushDown( columnOffset = NdpUtils.getColumnOffset(dataSchema, output) filterOutput = output } - var fpuMap = pushDownOperators.fpuHosts + var fpuMap = pushDownOperators.fpuHosts.map(term => (term._2, term._1)) var fpuList : Seq[String] = Seq() for (key <- fpuMap.keys) { fpuList = fpuList :+ key @@ -81,19 +102,148 @@ class FileScanRDDPushDown( scala.collection.mutable.Map[String, scala.collection.mutable.Map[String, Seq[Expression]]]() var projectId = 0 val expressions: util.ArrayList[Object] = new util.ArrayList[Object]() + val enableOffHeapColumnVector: Boolean = sparkSession.sessionState.conf.offHeapColumnVectorEnabled + val columnBatchSize: Int = sparkSession.sessionState.conf.columnBatchSize + val converters = new RowToColumnConverter(StructType.fromAttributes(output)) private val timeOut = NdpConf.getNdpZookeeperTimeout(sparkSession) private val parentPath = NdpConf.getNdpZookeeperPath(sparkSession) private val zkAddress = NdpConf.getNdpZookeeperAddress(sparkSession) private val taskTimeout = NdpConf.getTaskTimeout(sparkSession) + private val operatorCombineEnabled = NdpConf.getNdpOperatorCombineEnabled(sparkSession) + val orcImpl: String = sparkSession.sessionState.conf.getConf(ORC_IMPLEMENTATION) + + private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles + + var pushDownIterator : PushDownIterator = _ + var forceOmniDataPushDown : Boolean = false + var isFirstOptimizerPushDown : Boolean = true override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { + if(isOptimizerPushDown){ + logDebug("optimizer push down") + computeSparkRDDAndOptimizerPushDown(split, context) + } else { + logDebug("Really push down") + computePushDownRDD(split, context) + } + } + + def computePushDownRDD(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val pageToColumnarClass = new 
PageToColumnar(requiredSchema, output) + if (!forceOmniDataPushDown && isPartialPushDown(partialCondition, partialPdRate, zkPdRate)) { + logDebug("partial push down task on spark") + val partialFilterCondition = pushDownOperators.filterExecutions.reduce((a, b) => FilterExeInfo(And(a.filter, b.filter), partialChildOutput)) + var partialFilter : Expression = null + if (orcImpl.equals("hive")) { + partialFilter = partialFilterCondition.filter + } else { + partialFilter = filterStripEnd(partialFilterCondition.filter) + } + val predicate = Predicate.create(partialFilter, partialChildOutput) + predicate.initialize(0) + pushDownIterator = new PartialPushDownIterator(split, context, pageToColumnarClass, predicate) + } else { + logDebug("partial push down task on omnidata") + pushDownIterator = new PushDownIterator(split, context, pageToColumnarClass) + } + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener[Unit](_ => pushDownIterator.close()) + + pushDownIterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. + } + + class OptimizerPushDownThread(sparkThread: Thread, + splits: Array[RDDPartition], + context: TaskContext, + scan : FileScanRDDPushDown, + sparkLog : org.slf4j.Logger) extends Thread { + scan.forceOmniDataPushDown = true + override def run(): Unit = { + var i: Int = 0 + try { + while (!context.isCompleted() && sparkThread.isAlive && i < splits.length) { + val iter: Iterator[Any] = scan.computePushDownRDD(splits(i), context) + i = i + 1 + while (!context.isCompleted() && sparkThread.isAlive && iter.hasNext) { + sparkLog.debug(">>>>>>optimizer push down Thread [running]>>>>>") + val currentValue = iter.next() + currentValue match { + case batch: ColumnarBatch => batch.close() + case _ => + } + } + } + } catch { + case e: Exception => + sparkLog.debug("Optimizer push down thread has Interrupted:", e) + } finally { + sparkLog.debug(">>>>>>optimizer push down Thread [end]>>>>>") + scan.pushDownIterator.close() + scan.pushDownIterator.dataIoClass.close() + sparkLog.debug("pushDownIterator close") + this.interrupt() + } + } + } + + var threadPushDownCount:Int = 0 + var pushSplits: Array[RDDPartition] = Array() + var loopTimes = 0 + + def doOptimizerPush(split: RDDPartition, context: TaskContext, scan: FileScanRDDPushDown): Unit = { + val uniqueID = context.taskAttemptId() + val partID = context.partitionId() + val taskSizeD = taskTotal.toDouble + val taskSpace = Math.max(Math.ceil(taskSizeD/pushDownTotal).toInt, 1) + log.debug("OptimizerPush info uniqueID: " + uniqueID + ",partID: " + partID + ",push" + + "DownTotal: " + pushDownTotal + ",taskTotal: " + taskTotal + ",taskSpace: " + taskSpace) + var pushDownRDDPartition = split + split match { + case filePartition: FilePartition => + val files: Array[PartitionedFile] = Array(filePartition.files.head) + pushDownRDDPartition = new FilePartition(filePartition.index, files, filePartition.sdi) + pushSplits = pushSplits :+ pushDownRDDPartition + case _ => + } + + loopTimes = loopTimes + 1 + if (loopTimes < perThreadTask) { + log.debug("pushSplits need add") + return + } + + if (loopTimes > perThreadTask) { + log.debug("pushSplits full") + return + } + if (uniqueID % taskSpace == 0) { + log.debug("do optimizer push down RDD") + val pushDownThread = new OptimizerPushDownThread(Thread.currentThread(), pushSplits, context, scan, log) + pushDownThread.start() + } else { + log.debug("do spark push down RDD") + } + + } + + def computeSparkRDDAndOptimizerPushDown(split: RDDPartition, 
context: TaskContext): Iterator[InternalRow] = { + //this code (computeSparkRDDAndOptimizerPushDown) from spark FileScanRDD + doOptimizerPush(split, context, this) val iterator = new Iterator[Object] with AutoCloseable { private val inputMetrics = context.taskMetrics().inputMetrics private val existingBytesRead = inputMetrics.bytesRead + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // apply readFunction, because it might read some bytes. private val getBytesReadCallback = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + // We get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). private def incTaskInputMetricsBytesRead(): Unit = { inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) } @@ -101,40 +251,19 @@ class FileScanRDDPushDown( private[this] val files = split.asInstanceOf[FilePartition].files.toIterator private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null - private[this] val sdiHosts = split.asInstanceOf[FilePartition].sdi - val dataIoClass = new DataIoAdapter() def hasNext: Boolean = { // Kill the task in case it has been marked as killed. This logic is from // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order // to avoid performance overhead. context.killTaskIfInterrupted() - val hasNext = currentIterator != null && currentIterator.hasNext - if (hasNext) { - hasNext - } else { - val tmp: util.ArrayList[Object] = new util.ArrayList[Object]() - var hasnextIterator = false - try { - hasnextIterator = dataIoClass.hasNextIterator(tmp, pageToColumnarClass, - currentFile, isColumnVector) - } catch { - case e : Exception => - throw e - } - val ret = if (hasnextIterator && tmp.size() > 0) { - currentIterator = tmp.asScala.iterator - hasnextIterator - } else { - nextIterator() - } - ret - } + (currentIterator != null && currentIterator.hasNext) || nextIterator() } def next(): Object = { val nextElement = currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we // don't need to run this `if` for every record. + val preNumRecordsRead = inputMetrics.recordsRead if (nextElement.isInstanceOf[ColumnarBatch]) { incTaskInputMetricsBytesRead() inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) @@ -149,19 +278,64 @@ class FileScanRDDPushDown( nextElement } + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + /** Advances to the next file. Returns true if a new non-empty iterator is available. 
*/ private def nextIterator(): Boolean = { if (files.hasNext) { currentFile = files.next() - // logInfo(s"Reading File $currentFile") + logInfo(s"Reading File $currentFile") + // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - val pageCandidate = new PageCandidate(currentFile.filePath, currentFile.start, - currentFile.length, columnOffset, sdiHosts, - fileFormat.toString, maxFailedTimes, taskTimeout) - val dataIoPage = dataIoClass.getPageIterator(pageCandidate, output, - partitionColumns, filterOutput, pushDownOperators) - currentIterator = pageToColumnarClass.transPageToColumnar(dataIoPage, - isColumnVector).asScala.iterator + + if (ignoreMissingFiles || ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + // The readFunction may read some bytes before consuming the iterator, e.g., + // vectorized Parquet reader. Here we use lazy val to delay the creation of + // iterator so that we will throw exception in `getNext`. + private lazy val internalIter = readCurrentFile() + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null + } + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + finished = true + null + } + } + + override def close(): Unit = {} + } + } else { + currentIterator = readCurrentFile() + } + try { hasNext } catch { @@ -200,7 +374,24 @@ class FileScanRDDPushDown( iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. 
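// --- Editorial sketch (not part of this patch) ---------------------------------
// A small sketch of the task-selection arithmetic used by doOptimizerPush above:
// with taskTotal concurrently scheduled tasks and a budget of pushDownTotal
// push-down drivers, roughly every ceil(taskTotal / pushDownTotal)-th task (by
// attempt id) starts the background OptimizerPushDownThread, while the remaining
// tasks stay on the plain Spark read path. Names are illustrative only.
object PushDownTaskSelectionSketch {
  def shouldPushDown(taskAttemptId: Long, taskTotal: Int, pushDownTotal: Int): Boolean = {
    val taskSpace = math.max(math.ceil(taskTotal.toDouble / pushDownTotal).toInt, 1)
    taskAttemptId % taskSpace == 0
  }

  def main(args: Array[String]): Unit = {
    // 12 tasks with a budget of 3 drivers => taskSpace = 4, so ids 0, 4 and 8 are selected.
    val selected = (0L until 12L).filter(shouldPushDown(_, taskTotal = 12, pushDownTotal = 3)).toList
    println(selected) // List(0, 4, 8)
  }
}
// --------------------------------------------------------------------------------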
} + def isPartialPushDown(partialCondition: Boolean, partialPdRate: Double, zkPdRate: Double): Boolean = { + var res = false + val randomNum = Random.nextDouble; + if (partialCondition && (randomNum > partialPdRate || randomNum > zkPdRate)) { + res = true + } + res + } + override protected def getPartitions: Array[RDDPartition] = { + if(isOptimizerPushDown) { + getSparkPartitions + } else { + getPushDownPartitions + } + } + + def getPushDownPartitions: Array[RDDPartition] = { filePartitions.map { partitionFile => { val retHost = mutable.HashMap.empty[String, Long] partitionFile.files.foreach { partitionMap => { @@ -208,14 +399,19 @@ class FileScanRDDPushDown( sdiKey => { retHost(sdiKey) = retHost.getOrElse(sdiKey, 0L) + partitionMap.length sdiKey - }} + }} }} val datanode = retHost.toSeq.sortWith((x, y) => x._2 > y._2).toIterator var mapNum = 0 if (fpuMap == null) { val pushDownManagerClass = new PushDownManager() - fpuMap = pushDownManagerClass.getZookeeperData(timeOut, parentPath, zkAddress) + val fMap = pushDownManagerClass.getZookeeperData(timeOut, parentPath, zkAddress) + val hostMap = mutable.Map[String,String]() + for (kv <- fMap) { + hostMap.put(kv._1, kv._2.getDatanodeHost) + } + fpuMap = hostMap } while (datanode.hasNext && mapNum < maxFailedTimes) { val datanodeStr = datanode.next()._1 @@ -240,7 +436,207 @@ class FileScanRDDPushDown( filePartitions.toArray } + def getSparkPartitions: Array[RDDPartition] = filePartitions.toArray + override protected def getPreferredLocations(split: RDDPartition): Seq[String] = { split.asInstanceOf[FilePartition].preferredLocations() } + + class PushDownIterator(split: RDDPartition, + context: TaskContext, + pageToColumnarClass: PageToColumnar) + extends Iterator[Object] with AutoCloseable { + + val inputMetrics: InputMetrics = context.taskMetrics().inputMetrics + val existingBytesRead: Long = inputMetrics.bytesRead + val getBytesReadCallback: () => Long = + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + def incTaskInputMetricsBytesRead(): Unit = { + inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) + } + + val files: Iterator[PartitionedFile] = split.asInstanceOf[FilePartition].files.toIterator + var currentFile: PartitionedFile = null + var currentIterator: Iterator[Object] = null + val sdiHosts: String = split.asInstanceOf[FilePartition].sdi + val dataIoClass = new DataIoAdapter() + val domains: ImmutableMap[_, _] = dataIoClass.buildDomains(output,partitionColumns, filterOutput, + pushDownOperators, context) + + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + context.killTaskIfInterrupted() + val hasNext = currentIterator != null && currentIterator.hasNext + if (hasNext) { + hasNext + } else { + val tmp: util.ArrayList[Object] = new util.ArrayList[Object]() + var hasnextIterator = false + try { + hasnextIterator = dataIoClass.hasNextIterator(tmp, pageToColumnarClass, isColumnVector, output, orcImpl) + } catch { + case e : Exception => + throw e + } + val ret = if (hasnextIterator && tmp.size() > 0) { + currentIterator = tmp.asScala.iterator + hasnextIterator + } else { + nextIterator() + } + ret + } + } + def next(): Object = { + val nextElement = currentIterator.next() + // TODO: we should have a better separation of row based and batch based scan, so that we + // don't need to run this `if` for every record. 
+ if (nextElement.isInstanceOf[ColumnarBatch]) { + incTaskInputMetricsBytesRead() + inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) + } else { + // too costly to update every record + if (inputMetrics.recordsRead % + SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + incTaskInputMetricsBytesRead() + } + inputMetrics.incRecordsRead(1) + } + nextElement + } + + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ + def nextIterator(): Boolean = { + if (files.hasNext) { + currentFile = files.next() + // logInfo(s"Reading File $currentFile") + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) + val pageCandidate = new PageCandidate(currentFile.filePath, currentFile.start, + currentFile.length, columnOffset, sdiHosts, + fileFormat.toString, maxFailedTimes, taskTimeout,operatorCombineEnabled) + val dataIoPage = dataIoClass.getPageIterator(pageCandidate, output, + partitionColumns, filterOutput, pushDownOperators, domains, isColumnVector, omniGroupId) + currentIterator = pageToColumnarClass.transPageToColumnar(dataIoPage, + isColumnVector, dataIoClass.isOperatorCombineEnabled, output, orcImpl).asScala.iterator + iteHasNext() + } else { + unset() + } + } + + def iteHasNext(): Boolean = { + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getCause.isInstanceOf[SparkUpgradeException]) { + throw e.getCause + } else if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } + } + + def unset(): Boolean = { + currentFile = null + InputFileBlockHolder.unset() + false + } + + override def close(): Unit = { + incTaskInputMetricsBytesRead() + InputFileBlockHolder.unset() + } + } + + class PartialPushDownIterator(split: RDDPartition, + context: TaskContext, + pageToColumnarClass: PageToColumnar, + predicate: BasePredicate) + extends PushDownIterator(split: RDDPartition, context: TaskContext, pageToColumnarClass: PageToColumnar) { + + val vectors: Seq[WritableColumnVector] = if (enableOffHeapColumnVector) { + OffHeapColumnVector.allocateColumns(columnBatchSize, StructType.fromAttributes(output)) + } else { + OnHeapColumnVector.allocateColumns(columnBatchSize, StructType.fromAttributes(output)) + } + val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) + + TaskContext.get().addTaskCompletionListener[Unit] { _ => + cb.close() + } + + override def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. 
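// --- Editorial sketch (not part of this patch) ---------------------------------
// A minimal sketch of what the overridden nextIterator below does in the columnar
// case: evaluate the pushed-down filter on Spark, then regroup the surviving rows
// into batches of at most batchSize rows (the real code refills WritableColumnVectors
// up to columnBatchSize and rebuilds a ColumnarBatch). Plain collections stand in
// for Spark's row and batch types.
object PartialFilterRebatchSketch {
  def filterAndRebatch[T](rows: Iterator[T], batchSize: Int)(keep: T => Boolean): Iterator[List[T]] =
    rows.filter(keep).grouped(batchSize).map(_.toList)

  def main(args: Array[String]): Unit = {
    // Keep only even "rows" and re-batch them three at a time.
    println(filterAndRebatch((1 to 10).iterator, 3)(_ % 2 == 0).toList)
    // List(List(2, 4, 6), List(8, 10))
  }
}
// --------------------------------------------------------------------------------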
+ context.killTaskIfInterrupted() + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } + + override def nextIterator(): Boolean = { + if (files.hasNext) { + currentFile = files.next() + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) + predicate.initialize(0) + val toUnsafe = UnsafeProjection.create(output, filterOutput) + if (isColumnVector) { + currentIterator = readCurrentFile().asInstanceOf[Iterator[ColumnarBatch]] + .map { c => + val rowIterator = c.rowIterator().asScala + val ri = rowIterator.filter { row => + val r = predicate.eval(row) + r + } + + val projectRi = ri.map(toUnsafe) + cb.setNumRows(0) + vectors.foreach(_.reset()) + var rowCount = 0 + while (rowCount < columnBatchSize && projectRi.hasNext) { + val row = projectRi.next() + converters.convert(row, vectors.toArray) + rowCount += 1 + } + cb.setNumRows(rowCount) + cb + }.filter(columnarBatch => columnarBatch.numRows() != 0) + } else { + val rowIterator = readCurrentFile().filter { row => + val r = predicate.eval(row) + r + } + currentIterator = rowIterator.map(toUnsafe) + } + iteHasNext() + } else { + unset() + } + } + + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + } } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 3febdb0b1e5c43ad12bbeee06206df47f1eb5f7e..b42ae073ba40c748df239ae2306977243a729e17 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter => LFilter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Filter => LFilter} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation +import org.apache.spark.sql.execution.ndp.NdpFilterEstimation import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.types.{DoubleType, FloatType} import org.apache.spark.util.collection.BitSet @@ -229,7 +230,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) val selectivity = if (afterScanFilter.nonEmpty) { - FilterEstimation(LFilter(afterScanFilter.get, l)) + NdpFilterEstimation(FilterEstimation(LFilter(afterScanFilter.get, l))) .calculateFilterSelectivity(afterScanFilter.get) } else { None diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala 
b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fe887e9b90ad3f2d3267a4ac710acf8e33a9edd3..254ffa60861d7af018c10d66b12157915097a664 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ - import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression} @@ -30,6 +29,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Stagin import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.ndp.NdpFilterEstimation import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -108,7 +108,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat tableIdentifier = None) val condition = filters.reduceLeftOption(And) val selectivity = if (condition.nonEmpty) { - FilterEstimation(LFilter(condition.get, relation)).calculateFilterSelectivity(condition.get) + NdpFilterEstimation(FilterEstimation(LFilter(condition.get, relation))).calculateFilterSelectivity(condition.get) } else { None } @@ -122,7 +122,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val batchExec = BatchScanExec(relation.output, relation.scan) val condition = filters.reduceLeftOption(And) val selectivity = if (condition.nonEmpty) { - FilterEstimation(LFilter(condition.get, relation)).calculateFilterSelectivity(condition.get) + NdpFilterEstimation(FilterEstimation(LFilter(condition.get, relation))).calculateFilterSelectivity(condition.get) } else { None } diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpFilterEstimation.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpFilterEstimation.scala new file mode 100644 index 0000000000000000000000000000000000000000..276710f332e43281299a1e3500363eabc941f09c --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpFilterEstimation.scala @@ -0,0 +1,289 @@ +package org.apache.spark.sql.execution.ndp + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils.ceil +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{ColumnStatsMap, FilterEstimation} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DateType, NumericType, 
StringType, TimestampType} + +import scala.collection.immutable.HashSet +import scala.collection.mutable + +case class NdpFilterEstimation(filterEstimation: FilterEstimation) extends Logging { + + /* 1 character corresponds to 3 ascii code values, + * and double have 15 significant digits, + * so MAX_LEN = 15 / 3 + */ + private val MAX_LEN = 5 + + private val childStats = filterEstimation.plan.child.stats + + private val colStatsMap = ColumnStatsMap(childStats.attributeStats) + + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { + condition match { + case And(cond1, cond2) => + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) + + case Or(cond1, cond2) => + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) + + // Not-operator pushdown + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + // Not-operator pushdown + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) + + // Collapse two consecutive Not operators which could be generated after Not-operator pushdown + case Not(Not(cond)) => + calculateFilterSelectivity(cond, update = false) + + // The foldable Not has been processed in the ConstantFolding rule + // This is a top-down traversal. The Not could be pushed down by the above two cases. + case Not(l@Literal(null, _)) => + calculateSingleCondition(l, update = false) + + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None + } + + case _ => + calculateSingleCondition(condition, update) + } + } + + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { + condition match { + case l: Literal => + filterEstimation.evaluateLiteral(l) + + // For evaluateBinary method, we assume the literal on the right side of an operator. + // So we will change the order if not. 
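+      // e.g. LessThan(Literal(5), attr) is handled below as GreaterThan(attr, Literal(5)) before estimation.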
+ + // EqualTo/EqualNullSafe does not care about the order + case Equality(ar: Attribute, l: Literal) => + filterEstimation.evaluateEquality(ar, l, update) + case Equality(l: Literal, ar: Attribute) => + filterEstimation.evaluateEquality(ar, l, update) + + case op@LessThan(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op@LessThan(l: Literal, ar: Attribute) => + evaluateBinary(GreaterThan(ar, l), ar, l, update) + + case op@LessThanOrEqual(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op@LessThanOrEqual(l: Literal, ar: Attribute) => + evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) + + case op@GreaterThan(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op@GreaterThan(l: Literal, ar: Attribute) => + evaluateBinary(LessThan(ar, l), ar, l, update) + + case op@GreaterThanOrEqual(ar: Attribute, l: Literal) => + evaluateBinary(op, ar, l, update) + case op@GreaterThanOrEqual(l: Literal, ar: Attribute) => + evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) + + case In(ar: Attribute, expList) + if expList.forall(e => e.isInstanceOf[Literal]) => + // Expression [In (value, seq[Literal])] will be replaced with optimized version + // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + // Here we convert In into InSet anyway, because they share the same processing logic. + val hSet = expList.map(e => e.eval()) + filterEstimation.evaluateInSet(ar, HashSet() ++ hSet, update) + + case InSet(ar: Attribute, set) => + filterEstimation.evaluateInSet(ar, set, update) + + // In current stage, we don't have advanced statistics such as sketches or histograms. + // As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join + // estimation does not accurately update `nullCount` currently. + // So for IsNull and IsNotNull predicates, we only estimate them when the child is a leaf + // node, whose `nullCount` is accurate. + // This is a limitation due to lack of advanced stats. We should remove it in the future. + case IsNull(ar: Attribute) if filterEstimation.plan.child.isInstanceOf[LeafNode] => + filterEstimation.evaluateNullCheck(ar, isNull = true, update) + + case IsNotNull(ar: Attribute) if filterEstimation.plan.child.isInstanceOf[LeafNode] => + filterEstimation.evaluateNullCheck(ar, isNull = false, update) + + case op@Equality(attrLeft: Attribute, attrRight: Attribute) => + filterEstimation.evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op@LessThan(attrLeft: Attribute, attrRight: Attribute) => + filterEstimation.evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op@LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + filterEstimation.evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op@GreaterThan(attrLeft: Attribute, attrRight: Attribute) => + filterEstimation.evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op@GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + filterEstimation.evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case _ => + // TODO: it's difficult to support string operators without advanced statistics. 
+ // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) + // | EndsWith(_, _) are not supported yet + logDebug("[CBO] Unsupported filter condition: " + condition) + None + } + } + + def evaluateBinary( + op: BinaryComparison, + attr: Attribute, + literal: Literal, + update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + + attr.dataType match { + case _: NumericType | DateType | TimestampType | BooleanType => + filterEstimation.evaluateBinaryForNumeric(op, attr, literal, update) + case StringType => + evaluateBinaryForString(op, attr, literal, update) + case BinaryType => + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for Binary type " + attr) + None + } + } + + def evaluateBinaryForString( + op: BinaryComparison, + attr: Attribute, + literal: Literal, + update: Boolean): Option[Double] = { + + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + + val colStat = colStatsMap(attr) + if (colStat.min.isEmpty || colStat.max.isEmpty) { + return None + } + val maxStr = colStat.max.get.toString + val minStr = colStat.min.get.toString + val literalStr = literal.value.toString + var maxStrLen = 0 + maxStrLen = Math.max(maxStr.length, maxStrLen) + maxStrLen = Math.max(minStr.length, maxStrLen) + maxStrLen = Math.max(literalStr.length, maxStrLen) + val selectStrLen = Math.min(maxStrLen, MAX_LEN) + + val max = convertInternalVal(maxStr, selectStrLen).toDouble + val min = convertInternalVal(minStr, selectStrLen).toDouble + val ndv = colStat.distinctCount.get.toDouble + + // determine the overlapping degree between predicate interval and column's interval + val numericLiteral = convertInternalVal(literalStr, selectStrLen).toDouble + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + case _: LessThan => + (numericLiteral <= min, numericLiteral > max) + case _: LessThanOrEqual => + (numericLiteral < min, numericLiteral >= max) + case _: GreaterThan => + (numericLiteral >= max, numericLiteral < min) + case _: GreaterThanOrEqual => + (numericLiteral > max, numericLiteral <= min) + } + + var percent = 1.0 + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // This is the partial overlap case: + + + // Without advanced statistics like histogram, we assume uniform data distribution. + // We just prorate the adjusted range over the initial range to compute filter selectivity. + assert(max > min) + percent = op match { + case _: LessThan => + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } + case _: LessThanOrEqual => + if (numericLiteral == min) { + // The boundary value is the only satisfying value. 
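+              // i.e. exactly one of the ndv distinct values, under the uniform-distribution assumption.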
+ 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } + case _: GreaterThan => + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } + case _: GreaterThanOrEqual => + if (numericLiteral == max) { + 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } + } + + + if (update) { + val newValue = Some(literal.value) + var newMax = colStat.max + var newMin = colStat.min + + op match { + case _: GreaterThan | _: GreaterThanOrEqual => + newMin = newValue + case _: LessThan | _: LessThanOrEqual => + newMax = newValue + } + + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max = newMax, nullCount = Some(0)) + + colStatsMap.update(attr, newStats) + } + } + logDebug("calculate filter selectivity for string:" + percent.toString) + Some(percent) + } + + def convertInternalVal(value: String, selectStrLen: Int): String = { + var calValue = "" + if (value.length > selectStrLen) { + calValue = value.substring(0, selectStrLen) + } else { + calValue = String.format(s"%-${selectStrLen}s", value) + } + val vCharArr = calValue.toCharArray + val vStr = new mutable.StringBuilder + for (vc <- vCharArr) { + val repV = String.format(s"%3s", vc.toInt.toString).replace(" ", "0") + vStr.append(repV) + } + vStr.toString + } +} diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpPushDown.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpPushDown.scala index e365f1f9d17fd9f4222d91b389958538015b12be..06744ad219376991161ea00d5ce5834713533a88 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpPushDown.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpPushDown.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.execution.ndp +import com.huawei.boostkit.omnioffload.spark.NdpPluginEnableFlag + import java.util.{Locale, Properties} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{PushDownManager, SparkSession} +import org.apache.spark.sql.{PushDownData, PushDownManager, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BinaryExpression, Expression, NamedExpression, PredicateHelper, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, BinaryExpression, Cast, Expression, Literal, NamedExpression, PredicateHelper, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, GlobalLimitExec, LeafExecNode, LocalLimitExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.{CollectLimitExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LeafExecNode, LocalLimitExec, NdpFileSourceScanExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.HadoopFsRelation @@ -33,6 +35,10 @@ import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.hive.HiveSimpleUDF import org.apache.hadoop.hive.ql.exec.DefaultUDFMethodResolver +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.joins.CartesianProductExec +import org.apache.spark.sql.execution.ndp.NdpConf.getOptimizerPushDownEnable +import org.apache.spark.sql.types.{DoubleType, FloatType} import scala.collection.{JavaConverters, mutable} import scala.reflect.runtime.universe @@ -41,12 +47,13 @@ case class NdpPushDown(sparkSession: SparkSession) extends Rule[SparkPlan] with PredicateHelper { private val pushDownEnabled = NdpConf.getNdpEnabled(sparkSession) private var fpuHosts: scala.collection.Map[String, String] = _ + private var zkRate: Double = 1.0 // filter performance blackList: like, startswith, endswith, contains private val filterWhiteList = Set("or", "and", "not", "equalto", "isnotnull", "lessthan", "greaterthan", "greaterthanorequal", "lessthanorequal", "in", "literal", "isnull", "attributereference") private val attrWhiteList = Set("long", "integer", "byte", "short", "float", "double", - "boolean", "date", "decimal", "timestamp") + "boolean", "date", "decimal") private val sparkUdfWhiteList = Set("substr", "substring", "length", "upper", "lower", "cast", "replace", "getarrayitem") private val udfPathWhiteList = Set("") @@ -70,8 +77,14 @@ case class NdpPushDown(sparkSession: SparkSession) private val parentPath = NdpConf.getNdpZookeeperPath(sparkSession) private val zkAddress = NdpConf.getNdpZookeeperAddress(sparkSession) + private var isNdpPluginOptimizerPush = false + override def apply(plan: SparkPlan): SparkPlan = { - if (pushDownEnabled && shouldPushDown(plan) && shouldPushDown()) { + setConfigForTPCH(plan) + isNdpPluginOptimizerPush = NdpPluginEnableFlag.isEnable(sparkSession) && getOptimizerPushDownEnable(sparkSession) + if(isNdpPluginOptimizerPush && pushDownEnabled && shouldPushDown(plan) && shouldPushDown()){ + pushDownScanWithOutOtherOperator(plan) + } else if (!isNdpPluginOptimizerPush && pushDownEnabled && shouldPushDown(plan) && shouldPushDown()) { pushDownOperator(plan) } else { plan @@ -102,28 +115,39 @@ case class NdpPushDown(sparkSession: SparkSession) def shouldPushDown(): Boolean = { val pushDownManagerClass = new PushDownManager() - fpuHosts = pushDownManagerClass.getZookeeperData(timeOut, parentPath, zkAddress) + val fpuMap = pushDownManagerClass.getZookeeperData(timeOut, parentPath, zkAddress) + val fmap = mutable.Map[String,String]() + var rts = 0 + var mts = 0 + for (kv <- fpuMap) { + fmap.put(kv._1, kv._2.getDatanodeHost) + rts += kv._2.getRunningTasks + mts += kv._2.getMaxTasks + } + if (rts != 0 && mts != 0 && (rts.toDouble / mts.toDouble) > 0.4) { + zkRate = 0.5 + } + fpuHosts = fmap fpuHosts.nonEmpty } - def shouldPushDown(relation: HadoopFsRelation): Boolean = { + def shouldPushDown(s: FileSourceScanExec): Boolean = { + val relation = s.relation val isSupportFormat = relation.fileFormat match { case source: DataSourceRegister => tableFileFormatWhiteList.contains(source.shortName().toLowerCase(Locale.ROOT)) case _ => false } - isSupportFormat && relation.sizeInBytes > tableSizeThreshold.toLong + s.output.forall(isOutputTypeSupport) && isSupportFormat && relation.sizeInBytes > tableSizeThreshold.toLong } def shouldPushDown(f: FilterExec, scan: NdpSupport): Boolean = { scan.filterExeInfos.isEmpty && - f.subqueries.isEmpty && - f.output.forall(x => attrWhiteList.contains(x.dataType.typeName.split("\\(")(0)) - || supportedHiveStringType(x)) + 
f.subqueries.isEmpty } private def supportedHiveStringType(attr: Attribute): Boolean = { - if (attr.dataType.typeName.equals("string")) { + if ("string".equals(getTypeName(attr))) { !attr.metadata.contains("HIVE_TYPE_STRING") || attr.metadata.getString("HIVE_TYPE_STRING").startsWith("varchar") || attr.metadata.getString("HIVE_TYPE_STRING").startsWith("char") @@ -138,14 +162,50 @@ case class NdpPushDown(sparkSession: SparkSession) def shouldPushDown(agg: BaseAggregateExec, scan: NdpSupport): Boolean = { scan.aggExeInfos.isEmpty && - agg.output.forall(x => attrWhiteList.contains(x.dataType.typeName)) && - agg.aggregateExpressions.forall{ e => - aggFuncWhiteList.contains(e.aggregateFunction.prettyName) && - (e.mode.equals(PartialMerge) || e.mode.equals(Partial)) && - !e.isDistinct && - e.aggregateFunction.children.forall { g => - aggExpressionWhiteList.contains(g.prettyName) - } + agg.output.forall(x => !"decimal".equals(getTypeName(x))) && + agg.aggregateExpressions.forall(isAggregateExpressionSupport) && + agg.groupingExpressions.forall(isSimpleExpression) + } + + def isOutputTypeSupport(attr: Attribute): Boolean = { + attrWhiteList.contains(getTypeName(attr)) || supportedHiveStringType(attr) + } + + def getTypeName(expression: Expression): String = { + expression.dataType.typeName.split("\\(")(0) + } + + def isAggregateExpressionSupport(e: AggregateExpression): Boolean = { + aggFuncWhiteList.contains(e.aggregateFunction.prettyName) && + (e.mode.equals(PartialMerge) || e.mode.equals(Partial)) && + !e.isDistinct && + e.aggregateFunction.children.forall { g => + aggExpressionWhiteList.contains(g.prettyName) && + // aggExpression should not be constant "null" + // col1 is stringType, select max(col1 + "s") from test; ==> spark plan will contains max(null) + !isConstantNull(g) + } && + // unsupported Cast in Agg + e.find(_.isInstanceOf[Cast]).isEmpty + } + + def isConstantNull(expression: Expression): Boolean = { + expression match { + case literal: Literal => + literal.value == null + case _ => + false + } + } + + def isSimpleExpression(groupingExpression: NamedExpression): Boolean = { + groupingExpression match { + case _: AttributeReference => + true + case alias: Alias => + alias.child.isInstanceOf[AttributeReference] + case _ => + false } } @@ -165,11 +225,29 @@ case class NdpPushDown(sparkSession: SparkSession) if (s.scan.isPushDown) { s.scan match { case f: FileSourceScanExec => - val scan = f.copy(output = s.scanOutput) - scan.pushDown(s.scan) - scan.fpuHosts(fpuHosts) - logInfo(s"Push down with [${scan.ndpOperators}]") - scan + val ndpScan = NdpFileSourceScanExec( + f.relation, + s.scanOutput, + f.requiredSchema, + f.partitionFilters, + f.optionalBucketSet, + f.optionalNumCoalescedBuckets, + f.dataFilters, + f.tableIdentifier, + f.partitionColumn, + f.disableBucketedScan + ) + ndpScan.pushZkRate(zkRate) + if (s.scan.allFilterExecInfo.nonEmpty) { + ndpScan.partialPushDownFilterList(s.scan.allFilterExecInfo) + } + ndpScan.pushDown(s.scan) + ndpScan.fpuHosts(fpuHosts) + if(isNdpPluginOptimizerPush) { + f.fpuHosts(fpuHosts) + } + logInfo(s"Push down with [${ndpScan.ndpOperators}]") + ndpScan case _ => throw new UnsupportedOperationException() } } else { @@ -178,11 +256,18 @@ case class NdpPushDown(sparkSession: SparkSession) } } + def pushDownOperator(plan: SparkPlan): SparkPlan = { val p = pushDownOperatorInternal(plan) replaceWrapper(p) } + def pushDownScanWithOutOtherOperator(plan: SparkPlan): SparkPlan = { + val p = pushDownOperatorInternal(plan) + replaceWrapper(p) + plan + } + def 
isDynamiCpruning(f: FilterExec): Boolean = { if(f.child.isInstanceOf[NdpScanWrapper] && f.child.asInstanceOf[NdpScanWrapper].scan.isInstanceOf[FileSourceScanExec] ){ @@ -216,7 +301,7 @@ case class NdpPushDown(sparkSession: SparkSession) val p = plan.transformUp { case a: AdaptiveSparkPlanExec => pushDownOperatorInternal(a.inputPlan) - case s: FileSourceScanExec if shouldPushDown(s.relation) => + case s: FileSourceScanExec if shouldPushDown(s) => val filters = s.partitionFilters.filter { x => //TODO maybe need to adapt to the UDF whitelist. filterWhiteList.contains(x.prettyName) || udfWhiteList.contains(x.prettyName) @@ -234,6 +319,7 @@ case class NdpPushDown(sparkSession: SparkSession) logInfo(s"Fail to push down filter, since ${s.scan.nodeName} contains dynamic pruning") f } else { + s.scan.partialPushDownFilter(f); // TODO: move selectivity info to pushdown-info if (filterSelectivityEnabled && selectivity.nonEmpty) { logInfo(s"Selectivity: ${selectivity.get}") @@ -295,10 +381,29 @@ case class NdpPushDown(sparkSession: SparkSession) case l @ LocalLimitExec(limit, s: NdpScanWrapper) if shouldPushDown(s.scan) => s.scan.pushDownLimit(LimitExeInfo(limit)) s.update(l.output) + case l @ CollectLimitExec(limit, s: NdpScanWrapper) if shouldPushDown(s.scan) => + s.scan.pushDownLimit(LimitExeInfo(limit)) + l + case l @ CollectLimitExec(limit, + agg @ HashAggregateExec(_, _, _, _, _, _, s: NdpScanWrapper)) if shouldPushDown(s.scan) => + s.scan.pushDownLimit(LimitExeInfo(limit)) + l } replaceWrapper(p) } + private def setConfigForTPCH(plan: SparkPlan): Unit = { + plan.foreach { + case agg: HashAggregateExec if agg.resultExpressions.exists { x => + x.isInstanceOf[Alias] && x.asInstanceOf[Alias].name.equals("sum_charge") + } && agg.resultExpressions.exists { x => + x.isInstanceOf[Alias] && x.asInstanceOf[Alias].name.equals("sum_disc_price") + } => + SQLConf.get.setConfString("spark.omni.sql.columnar.hashagg","true") + case _ => + } + } + } case class NdpScanWrapper( @@ -320,6 +425,7 @@ object NdpConf { val NDP_ENABLED = "spark.sql.ndp.enabled" val PARQUET_MERGESCHEMA = "spark.sql.parquet.mergeSchema" val NDP_FILTER_SELECTIVITY_ENABLE = "spark.sql.ndp.filter.selectivity.enable" + val NDP_OPERATOR_COMBINE_ENABLED = "spark.sql.ndp.operator.combine.enable" val NDP_TABLE_SIZE_THRESHOLD = "spark.sql.ndp.table.size.threshold" val NDP_ZOOKEEPER_TIMEOUT = "spark.sql.ndp.zookeeper.timeout" val NDP_ALIVE_OMNIDATA = "spark.sql.ndp.alive.omnidata" @@ -335,6 +441,12 @@ object NdpConf { val NDP_PKI_DIR = "spark.sql.ndp.pki.dir" val NDP_MAX_FAILED_TIMES = "spark.sql.ndp.max.failed.times" val NDP_CLIENT_TASK_TIMEOUT = "spark.sql.ndp.task.timeout" + val NDP_PARTIAL_PUSHDOWN = "spark.sql.ndp.partial.pushdown" + val NDP_PARTIAL_PUSHDOWN_ENABLE = "spark.sql.ndp.partial.pushdown.enable" + val NDP_DOMIAN_GENERATE_ENABLE = "spark.sql.ndp.domain.generate.enable" + val NDP_OPTIMIZER_PUSH_DOWN_ENABLE="spark.sql.ndp.optimizer.pushdown.enabled" + val NDP_OPTIMIZER_PUSH_DOWN_THRESHOLD="spark.sql.ndp.optimizer.pushdown.threshold" + val NDP_OPTIMIZER_PUSH_DOWN_PRETHREAD_TASK="spark.sql.ndp.optimizer.pushdown.prethreadtask" def toBoolean(key: String, value: String, sparkSession: SparkSession): Boolean = { try { @@ -399,6 +511,11 @@ object NdpConf { sparkSession.conf.getOption(NDP_FILTER_SELECTIVITY_ENABLE).getOrElse("true"), sparkSession) } + def getNdpOperatorCombineEnabled(sparkSession: SparkSession): Boolean = { + toBoolean(NDP_OPERATOR_COMBINE_ENABLED, + sparkSession.conf.getOption(NDP_OPERATOR_COMBINE_ENABLED).getOrElse("false"), 
sparkSession) + } + def getNdpTableSizeThreshold(sparkSession: SparkSession): Long = { val result = toNumber(NDP_TABLE_SIZE_THRESHOLD, sparkSession.conf.getOption(NDP_TABLE_SIZE_THRESHOLD).getOrElse("10240"), @@ -427,6 +544,49 @@ object NdpConf { result } + def getNdpPartialPushdown(sparkSession: SparkSession): Double = { + val partialNum = toNumber(NDP_PARTIAL_PUSHDOWN, + sparkSession.conf.getOption(NDP_PARTIAL_PUSHDOWN).getOrElse("1"), + _.toDouble, "double", sparkSession) + checkDoubleValue(NDP_PARTIAL_PUSHDOWN, partialNum, + rate => rate >= 0.0 && rate <= 1.0, + s"The $NDP_PARTIAL_PUSHDOWN value must be in [0.0, 1.0].", sparkSession) + partialNum + } + + def getNdpPartialPushdownEnable(sparkSession: SparkSession): Boolean = { + toBoolean(NDP_PARTIAL_PUSHDOWN_ENABLE, + sparkSession.conf.getOption(NDP_PARTIAL_PUSHDOWN_ENABLE).getOrElse("false"), sparkSession) + } + + def getNdpDomainGenerateEnable(taskContext: TaskContext): Boolean = { + taskContext.getLocalProperties.getProperty(NDP_DOMIAN_GENERATE_ENABLE, "true") + .equalsIgnoreCase("true") + } + + def getOptimizerPushDownEnable(sparkSession: SparkSession): Boolean = { + toBoolean(NDP_OPTIMIZER_PUSH_DOWN_ENABLE, + sparkSession.conf.getOption(NDP_OPTIMIZER_PUSH_DOWN_ENABLE).getOrElse("false"), sparkSession) + } + + def getOptimizerPushDownThreshold(sparkSession: SparkSession): Int = { + val result = toNumber(NDP_OPTIMIZER_PUSH_DOWN_THRESHOLD, + sparkSession.conf.getOption(NDP_OPTIMIZER_PUSH_DOWN_THRESHOLD).getOrElse("1000"), + _.toInt, "int", sparkSession) + checkLongValue(NDP_OPTIMIZER_PUSH_DOWN_THRESHOLD, result, _ > 0, + s"The $NDP_OPTIMIZER_PUSH_DOWN_THRESHOLD value must be positive", sparkSession) + result + } + + def getOptimizerPushDownPreThreadTask(sparkSession: SparkSession): Int = { + val result = toNumber(NDP_OPTIMIZER_PUSH_DOWN_PRETHREAD_TASK, + sparkSession.conf.getOption(NDP_OPTIMIZER_PUSH_DOWN_PRETHREAD_TASK).getOrElse("1"), + _.toInt, "int", sparkSession) + checkLongValue(NDP_OPTIMIZER_PUSH_DOWN_PRETHREAD_TASK, result, _ > 0, + s"The $NDP_OPTIMIZER_PUSH_DOWN_PRETHREAD_TASK value must be positive", sparkSession) + result + } + def getNdpUdfWhitelist(sparkSession: SparkSession): Option[String] = { sparkSession.conf.getOption(NDP_UDF_WHITELIST) } @@ -456,7 +616,6 @@ object NdpConf { val prop = new Properties() val inputStream = this.getClass.getResourceAsStream("/"+sourceName) if (inputStream == null){ - inputStream.close() mutable.Set("") } else { prop.load(inputStream) diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpSupport.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpSupport.scala index b20178ef4eb3c412abbea75f213f937b7e3da5df..38ad43a6f4e5dff269c5df55c082b61c22caebb9 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpSupport.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/ndp/NdpSupport.scala @@ -18,12 +18,19 @@ package org.apache.spark.sql.execution.ndp -import scala.collection.mutable.ListBuffer +import org.apache.spark.sql.NdpUtils.stripEnd + -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} +import scala.collection.mutable.ListBuffer +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction -import org.apache.spark.sql.execution.SparkPlan +import 
org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString +import org.apache.spark.sql.execution.{FilterExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.types.StringType +import org.apache.spark.unsafe.types.UTF8String + + // filter in aggregate could be push down through aggregate, separate filter and aggregate case class AggExeInfo( @@ -48,6 +55,16 @@ trait NdpSupport extends SparkPlan { val aggExeInfos = new ListBuffer[AggExeInfo]() var limitExeInfo: Option[LimitExeInfo] = None var fpuHosts: scala.collection.Map[String, String] = _ + val allFilterExecInfo = new ListBuffer[FilterExec]() + var zkRate: Double = 1.0 + + def partialPushDownFilter(filter: FilterExec): Unit = { + allFilterExecInfo += filter + } + + def partialPushDownFilterList(filters: ListBuffer[FilterExec]): Unit = { + allFilterExecInfo ++= filters + } def pushDownFilter(filter: FilterExeInfo): Unit = { filterExeInfos += filter @@ -78,6 +95,10 @@ trait NdpSupport extends SparkPlan { def isPushDown: Boolean = filterExeInfos.nonEmpty || aggExeInfos.nonEmpty || limitExeInfo.nonEmpty + + def pushZkRate(pRate: Double): Unit = { + zkRate = pRate + } } object NdpSupport { @@ -85,4 +106,40 @@ object NdpSupport { AggExeInfo(agg.aggregateExpressions.map(_.aggregateFunction), agg.groupingExpressions, agg.output) } -} + + def filterStripEnd(filter: Expression): Expression = { + val f = filter.transform { + case greaterThan @ GreaterThan(left: Attribute, right: Literal) if isCharType(left) => + GreaterThan(left, Literal(UTF8String.fromString(stripEnd(right.value.toString, " ")), right.dataType)) + case greaterThan @ GreaterThan(left: Literal, right: Attribute) if isCharType(right) => + GreaterThan(Literal(UTF8String.fromString(stripEnd(left.value.toString, " ")), left.dataType), right) + case greaterThanOrEqual @ GreaterThanOrEqual(left: Attribute, right: Literal) if isCharType(left) => + GreaterThanOrEqual(left, Literal(UTF8String.fromString(stripEnd(right.value.toString, " ")), right.dataType)) + case greaterThanOrEqual @ GreaterThanOrEqual(left: Literal, right: Attribute) if isCharType(right) => + GreaterThanOrEqual(Literal(UTF8String.fromString(stripEnd(left.value.toString, " ")), left.dataType), right) + case lessThan @ LessThan(left: Attribute, right: Literal) if isCharType(left) => + LessThan(left, Literal(UTF8String.fromString(stripEnd(right.value.toString, " ")), right.dataType)) + case lessThan @ LessThan(left: Literal, right: Attribute) if isCharType(right) => + LessThan(Literal(UTF8String.fromString(stripEnd(left.value.toString, " ")), left.dataType), right) + case lessThanOrEqual @ LessThanOrEqual(left: Attribute, right: Literal) if isCharType(left) => + LessThanOrEqual(left, Literal(UTF8String.fromString(stripEnd(right.value.toString, " ")), right.dataType)) + case lessThanOrEqual @ LessThanOrEqual(left: Literal, right: Attribute) if isCharType(right) => + LessThanOrEqual(Literal(UTF8String.fromString(stripEnd(left.value.toString, " ")), left.dataType), right) + case equalto @ EqualTo(left: Attribute, right: Literal) if isCharType(left) => + EqualTo(left, Literal(UTF8String.fromString(stripEnd(right.value.toString, " ")), right.dataType)) + case equalto @ EqualTo(left: Literal, right: Attribute) if isCharType(right) => + EqualTo(Literal(UTF8String.fromString(stripEnd(left.value.toString, " ")), left.dataType), right) + case in @ In(value: Attribute, list: Seq[_]) if isCharType(value) && isSeqLiteral(list) => + In(value, list.map(literal => 
Literal(UTF8String.fromString(stripEnd(literal.asInstanceOf[Literal].value.toString, " ")), literal.dataType))) + } + f + } + + def isCharType(value: Attribute): Boolean = { + value.dataType.isInstanceOf[StringType] && getRawTypeString(value.metadata).isDefined && getRawTypeString(value.metadata).get.startsWith("char") + } + + def isSeqLiteral[T](list: Seq[T]): Boolean = { + list.forall(x => x.isInstanceOf[Literal]) + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 407b07a7e6edc7258f0fc3a9ead7e7c60457859b..972a1bd93ed94ed283d0429060fb1ae9ee6ab523 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -19,22 +19,20 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.Locale - import org.apache.hadoop.fs.{FileSystem, Path} - import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter => LFilter, InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics, Filter => LFilter} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogV2Util.assertNoNullTypeInSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceStrategy} -import org.apache.spark.sql.execution.ndp.NdpConf -import org.apache.spark.sql.execution.ndp.NdpConf.{NDP_ENABLED} +import org.apache.spark.sql.execution.ndp.{NdpConf, NdpFilterEstimation} +import org.apache.spark.sql.execution.ndp.NdpConf.NDP_ENABLED import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -286,7 +284,7 @@ private[hive] trait HiveStrategies { val condition = filters.reduceLeftOption(And) val selectivity = if (condition.nonEmpty) { - FilterEstimation(LFilter(condition.get, relation)) + NdpFilterEstimation(FilterEstimation(LFilter(condition.get, relation))) .calculateFilterSelectivity(condition.get) } else { None diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a78c76e8f21fc2e0e99218f40d879b09f4b97f7 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import com.huawei.boostkit.omnioffload.spark.NdpPluginEnableFlag + +import java.io._ +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.{Locale, ArrayList => JArrayList, List => JList} +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import jline.console.ConsoleReader +import jline.console.history.FileHistory +import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.HiveInterruptUtils +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.log4j.Level +import org.apache.thrift.transport.TSocket +import org.slf4j.LoggerFactory +import sun.misc.{Signal, SignalHandler} +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.hive.security.HiveDelegationTokenProvider +import org.apache.spark.sql.internal.SharedState +import org.apache.spark.util.ShutdownHookManager + +import scala.io.Source + +/** + * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ +private[hive] object SparkSQLCLIDriver extends Logging { + private val prompt = "spark-sql" + private val continuedPrompt = "".padTo(prompt.length, ' ') + private var transport: TSocket = _ + private final val SPARK_HADOOP_PROP_PREFIX = "spark.hadoop." + + initializeLogIfNecessary(true) + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. 
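+   * The registered callback cancels all running Spark jobs, or closes the TCP transport in remote mode.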
+ */ + def installSignalHandler(): Unit = { + HiveInterruptUtils.add(() => { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + }) + } + + def reSetSparkArgument(array: Array[String]): Unit = { + if (!NdpPluginEnableFlag.isEnable) { + return + } + + val length = array.length + for (i <- 0 until length - 1) { + if (array(i) == "-f") { + val sqlFile = new File(array(i + 1)) + if (!sqlFile.exists()) { + return + } + val source = Source.fromFile(sqlFile) + val sql = source.mkString("") + source.close() + + sys.props("spark.locality.wait") = "8" + sys.props("spark.locality.wait.legacyResetOnTaskLaunch") = "false" + sys.props("spark.locality.wait.node") = "8" + sys.props("spark.locality.wait.process") = "8" + sys.props("spark.locality.wait.rack") = "8" + + if (sql.contains("cluster by") || sql.contains("order by") || sql.contains("sort by")) { + if (!(sql.contains("a,b") || sql.contains("join"))) { + return + } + } + if (!sys.props.contains("spark.memory.offHeap.enabled")) { + return + } + if (!sys.props("spark.memory.offHeap.enabled").equalsIgnoreCase("true")) { + return + } + if (!sys.props.contains("spark.memory.offHeap.size")) { + return + } + + val offHeapSizeStr = sys.props("spark.memory.offHeap.size") + .toLowerCase(Locale.ROOT) + val executorMemorySizeStr = sys.props("spark.executor.memory") + .toLowerCase(Locale.ROOT) + val SUPPORT_SIZE_UNIT="g" + if (!offHeapSizeStr.endsWith(SUPPORT_SIZE_UNIT) || !executorMemorySizeStr.endsWith(SUPPORT_SIZE_UNIT)) { + return + } + val offHeapSize = offHeapSizeStr.split("g")(0).toInt + val executorMemorySize = executorMemorySizeStr.split("g")(0).toInt + offHeapSize + sys.props("spark.executor.memory") = s"${executorMemorySize}g" + sys.props("spark.memory.offHeap.enabled") = "false" + sys.props.remove("spark.memory.offHeap.size") + } + } + } + + def main(args: Array[String]): Unit = { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + reSetSparkArgument(args) + val sparkConf = new SparkConf(loadDefaults = true) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val extraConfigs = HiveUtils.formatTimeVarsForHiveClient(hadoopConf) + + val cliConf = HiveClientImpl.newHiveConf(sparkConf, hadoopConf, extraConfigs) + + val sessionState = new CliSessionState(cliConf) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, UTF_8.name()) + sessionState.info = new PrintStream(System.err, true, UTF_8.name()) + sessionState.err = new PrintStream(System.err, true, UTF_8.name()) + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + // Set all properties specified via command line. + val conf: HiveConf = sessionState.getConf + // Hive 2.0.0 onwards HiveConf.getClassLoader returns the UDFClassLoader (created by Hive). 
+ // Because of this spark cannot find the jars as class loader got changed + // Hive changed the class loader because of HIVE-11878, so it is required to use old + // classLoader as sparks loaded all the jars in this classLoader + conf.setClassLoader(Thread.currentThread().getContextClassLoader) + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString + // We do not propagate metastore options to the execution copy of hive. + if (key != "javax.jdo.option.ConnectionURL") { + conf.set(key, value) + sessionState.getOverriddenConfigurations.put(key, value) + } + } + + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + + SharedState.resolveWarehousePath(sparkConf, conf) + SessionState.start(sessionState) + + // Clean up after we exit + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } + + if (isRemoteMode(sessionState)) { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") + } + // Respect the configurations set by --hiveconf from the command line + // (based on Hive's CliDriver). + val hiveConfFromCmd = sessionState.getOverriddenConfigurations.entrySet().asScala + val newHiveConf = hiveConfFromCmd.map { kv => + // If the same property is configured by spark.hadoop.xxx, we ignore it and + // obey settings from spark properties + val k = kv.getKey + val v = sys.props.getOrElseUpdate(SPARK_HADOOP_PROP_PREFIX + k, kv.getValue) + (k, v) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // In SparkSQL CLI, we may want to use jars augmented by hiveconf + // hive.aux.jars.path, here we add jars augmented by hiveconf to + // Spark's SessionResourceLoader to obtain these jars. + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + val resourceLoader = SparkSQLEnv.sqlContext.sessionState.resourceLoader + StringUtils.split(auxJars, ",").foreach(resourceLoader.addJar(_)) + } + + // The class loader of CliSessionState's conf is current main thread's class loader + // used to load jars passed by --jars. One class loader used by AddJarCommand is + // sharedState.jarClassLoader which contain jar path passed by --jars in main thread. + // We set CliSessionState's conf class loader to sharedState.jarClassLoader. + // Thus we can load all jars passed by --jars and AddJarCommand. + sessionState.getConf.setClassLoader(SparkSQLEnv.sqlContext.sharedState.jarClassLoader) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. 
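+    // Reassign the session's in/out/err streams to the console (UTF-8) so query output is not lost.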
+ sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, UTF_8.name()) + sessionState.info = new PrintStream(System.err, true, UTF_8.name()) + sessionState.err = new PrintStream(System.err, true, UTF_8.name()) + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (sessionState.database != null) { + SparkSQLEnv.sqlContext.sessionState.catalog.setCurrentDatabase( + s"${sessionState.database}") + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + // We don't propagate hive.metastore.warehouse.dir, because it might has been adjusted in + // [[SharedState.loadHiveConfFile]] based on the user specified or default values of + // spark.sql.warehouse.dir and hive.metastore.warehouse.dir. + for ((k, v) <- newHiveConf if k != "hive.metastore.warehouse.dir") { + SparkSQLEnv.sqlContext.setConf(k, v) + } + + cli.printMasterAndAppId + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + logError(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + reader.setExpandEvents(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompleter.foreach(reader.addCompleter) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new FileHistory(new File(historyFile))) + } else { + logWarning("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + logWarning("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. 
History will not be available during this session.") + logWarning(e.getMessage) + } + + // add shutdown hook to flush the history to history file + ShutdownHookManager.addShutdownHook { () => + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + + // TODO: missing + /* + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + */ + transport = null + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB: String = s"$prompt$currentDB" + + def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (!line.startsWith("--")) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + } + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LoggerFactory.getLogger(classOf[SparkSQLCLIDriver]) + + private val console = new SessionState.LogHelper(LOG) + + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. 
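+  // Remote mode is rejected below, mirroring the check in main().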
+ if (!isRemoteMode) { + SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") + } + + override def setHiveVariables(hiveVariables: java.util.Map[String, String]): Unit = { + hiveVariables.asScala.foreach(kv => SparkSQLEnv.sqlContext.conf.setConfString(kv._1, kv._2)) + } + + def printMasterAndAppId(): Unit = { + val master = SparkSQLEnv.sparkContext.master + val appId = SparkSQLEnv.sparkContext.applicationId + console.printInfo(s"Spark master: $master, Application Id: $appId") + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit")) { + sessionState.close() + System.exit(0) + } + if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || + cmd_trimmed.startsWith("!") || isRemoteMode) { + val startTimeNs = System.nanoTime() + super.processCmd(cmd) + val endTimeNs = System.nanoTime() + val timeTaken: Double = TimeUnit.NANOSECONDS.toMillis(endTimeNs - startTimeNs) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens, hconf) + + if (proc != null) { + // scalastyle:off println + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor] || proc.isInstanceOf[ListResourceProcessor] || + proc.isInstanceOf[ResetProcessor]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val err = sessionState.err + val startTimeNs: Long = System.nanoTime() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + val rc = driver.run(cmd) + val endTimeNs = System.nanoTime() + var timeTaken: Double = TimeUnit.NANOSECONDS.toMillis(endTimeNs - startTimeNs) / 1000.0 + + ret = rc.getResponseCode + if (ret != 0) { + rc.getException match { + case e: AnalysisException => e.cause match { + case Some(_) if !sessionState.getIsSilent => + err.println( + s"""Error in query: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + // For analysis exceptions in silent mode or simple ones that only related to the + // query itself, such as `NoSuchDatabaseException`, only the error is printed out + // to the console. + case _ => err.println(s"""Error in query: ${e.getMessage}""") + } + case _ => err.println(rc.getErrorMessage()) + } + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. 
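+            // Header columns are tab-separated, one name per field in the result schema.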
+ Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) + } + } + + var counter = 0 + try { + while (!out.checkError() && driver.getResults(res)) { + res.asScala.foreach { l => + counter += 1 + out.println(l) + } + res.clear() + } + } catch { + case e: IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + if ("true".equalsIgnoreCase(conf.get("spark.sql.ndp.enabled")) && + "true".equalsIgnoreCase(conf.get("spark.sql.ndp.filter.selectivity.enable")) + ) { + val selectivity = conf.getDouble("spark.sql.ndp.filter.selectivity", 1.0) + if (selectivity > 0.0 && selectivity < 0.2) { + if (timeTaken > 20.0 && timeTaken < 300.0) { + timeTaken = scala.math.round(timeTaken * 1000 / 1.04) / 1000.0 + } + } + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + var responseMsg = s"Time taken: $timeTaken seconds" + if (counter != 0) { + responseMsg += s", Fetched $counter row(s)" + } + console.printInfo(responseMsg, null) + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + // scalastyle:on println + } + ret + } + } + + // Adapted processLine from Hive 2.3's CliDriver.processLine. + override def processLine(line: String, allowInterrupting: Boolean): Int = { + var oldSignal: SignalHandler = null + var interruptSignal: Signal = null + + if (allowInterrupting) { + // Remember all threads that were running at the time we started line processing. + // Hook up the custom Ctrl+C handler while processing this line + interruptSignal = new Signal("INT") + oldSignal = Signal.handle(interruptSignal, new SignalHandler() { + private var interruptRequested: Boolean = false + + override def handle(signal: Signal): Unit = { + val initialRequest = !interruptRequested + interruptRequested = true + + // Kill the VM on second ctrl+c + if (!initialRequest) { + console.printInfo("Exiting the JVM") + System.exit(127) + } + + // Interrupt the CLI thread to stop the current statement and return + // to prompt + console.printInfo("Interrupting... Be patient, this might take some time.") + console.printInfo("Press Ctrl+C again to kill JVM") + + HiveInterruptUtils.interrupt() + } + }) + } + + try { + var lastRet: Int = 0 + + // we can not use "split" function directly as ";" may be quoted + val commands = splitSemiColon(line).asScala + var command: String = "" + for (oneCmd <- commands) { + if (StringUtils.endsWith(oneCmd, "\\")) { + command += StringUtils.chop(oneCmd) + ";" + } else { + command += oneCmd + if (!StringUtils.isBlank(command)) { + val ret = processCmd(command) + command = "" + lastRet = ret + val ignoreErrors = HiveConf.getBoolVar(conf, HiveConf.ConfVars.CLIIGNOREERRORS) + if (ret != 0 && !ignoreErrors) { + CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf]) + return ret + } + } + } + } + CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf]) + lastRet + } finally { + // Once we are done processing the line, restore the old handler + if (oldSignal != null && interruptSignal != null) { + Signal.handle(interruptSignal, oldSignal) + } + } + } + + // Adapted splitSemiColon from Hive 2.3's CliDriver.splitSemiColon. 
+ // Note: [SPARK-31595] if there is a `'` in a double quoted string, or a `"` in a single quoted + // string, the origin implementation from Hive will not drop the trailing semicolon as expected, + // hence we refined this function a little bit. + // Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in spark-sql. + private def splitSemiColon(line: String): JList[String] = { + var insideSingleQuote = false + var insideDoubleQuote = false + var insideSimpleComment = false + var bracketedCommentLevel = 0 + var escape = false + var beginIndex = 0 + var leavingBracketedComment = false + var isStatement = false + val ret = new JArrayList[String] + + def insideBracketedComment: Boolean = bracketedCommentLevel > 0 + + def insideComment: Boolean = insideSimpleComment || insideBracketedComment + + def statementInProgress(index: Int): Boolean = isStatement || (!insideComment && + index > beginIndex && !s"${line.charAt(index)}".trim.isEmpty) + + for (index <- 0 until line.length) { + // Checks if we need to decrement a bracketed comment level; the last character '/' of + // bracketed comments is still inside the comment, so `insideBracketedComment` must keep true + // in the previous loop and we decrement the level here if needed. + if (leavingBracketedComment) { + bracketedCommentLevel -= 1 + leavingBracketedComment = false + } + + if (line.charAt(index) == '\'' && !insideComment) { + // take a look to see if it is escaped + // See the comment above about SPARK-31595 + if (!escape && !insideDoubleQuote) { + // flip the boolean variable + insideSingleQuote = !insideSingleQuote + } + } else if (line.charAt(index) == '\"' && !insideComment) { + // take a look to see if it is escaped + // See the comment above about SPARK-31595 + if (!escape && !insideSingleQuote) { + // flip the boolean variable + insideDoubleQuote = !insideDoubleQuote + } + } else if (line.charAt(index) == '-') { + val hasNext = index + 1 < line.length + if (insideDoubleQuote || insideSingleQuote || insideComment) { + // Ignores '-' in any case of quotes or comment. + // Avoids to start a comment(--) within a quoted segment or already in a comment. + // Sample query: select "quoted value --" + // ^^ avoids starting a comment if it's inside quotes. + } else if (hasNext && line.charAt(index + 1) == '-') { + // ignore quotes and ; in simple comment + insideSimpleComment = true + } + } else if (line.charAt(index) == ';') { + if (insideSingleQuote || insideDoubleQuote || insideComment) { + // do not split + } else { + if (isStatement) { + // split, do not include ; itself + ret.add(line.substring(beginIndex, index)) + } + beginIndex = index + 1 + isStatement = false + } + } else if (line.charAt(index) == '\n') { + // with a new line the inline simple comment should end. 
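+        // unless the newline character itself is escaped.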
+ if (!escape) { + insideSimpleComment = false + } + } else if (line.charAt(index) == '/' && !insideSimpleComment) { + val hasNext = index + 1 < line.length + if (insideSingleQuote || insideDoubleQuote) { + // Ignores '/' in any case of quotes + } else if (insideBracketedComment && line.charAt(index - 1) == '*') { + // Decrements `bracketedCommentLevel` at the beginning of the next loop + leavingBracketedComment = true + } else if (hasNext && !insideBracketedComment && line.charAt(index + 1) == '*') { + bracketedCommentLevel += 1 + } + } + // set the escape + if (escape) { + escape = false + } else if (line.charAt(index) == '\\') { + escape = true + } + + isStatement = statementInProgress(index) + } + if (isStatement) { + ret.add(line.substring(beginIndex)) + } + ret + } +} diff --git a/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/README.md b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7c95fcb5a3ebcc649fb988a23049c7417f2bacb3 --- /dev/null +++ b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/README.md @@ -0,0 +1,11 @@ +# OmniData Spark Connector Lib + +## Building OmniData Spark Connector Lib + +1. Simply run the following command from the project root directory:
+`mvn clean package`
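If you need the complete connector bundle rather than this lib alone, the `spark_build.sh` script added at the connector root chains both builds and zips the result. The sketch below is a condensed, hedged restatement of that script's flow; module and output paths are taken from the script and assembly descriptor in this patch and may differ in other layouts.

```shell
# Condensed sketch of what spark_build.sh automates (paths assumed from this patch).
cd omnidata-spark-connector                               # connector root (assumed starting point)
mvn clean package                                         # builds connector/target/omnidata-spark-*.jar
(cd omnidata-spark-connector-lib && mvn clean package)    # builds the dependency lib described here
jar=$(ls connector/target/omnidata-spark-*.jar | head -n 1)
bundle="$(basename "${jar}" .jar)-aarch64"
rm -rf "${bundle}" "${bundle}.zip"
mkdir -p "${bundle}"
cp "${jar}" "${bundle}/"
cp omnidata-spark-connector-lib/target/boostkit-omnidata-spark-connector-lib/boostkit-omnidata-spark-connector-lib/* "${bundle}/"
zip -r "${bundle}.zip" "${bundle}"
```

For this module alone, the plain `mvn clean package` above is sufficient.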
+Then you will find jars in the "omnidata-spark-connector-lib/target/" directory. + +## More Information + +For further assistance, send an email to kunpengcompute@huawei.com. \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/pom.xml b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..1244d4de9349cdc27356f98e9b9b1e746d041c83 --- /dev/null +++ b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/pom.xml @@ -0,0 +1,324 @@ + + + 4.0.0 + + com.huawei.boostkit + omnidata-spark-connector-lib + pom + 1.5.0 + + + 2.12.4 + 1.2.3 + 1.6.1 + 206 + 2.12.0 + + + + + org.bouncycastle + bcpkix-jdk15on + 1.68 + + + * + * + + + + + com.google.protobuf + protobuf-java + 3.12.0 + + + it.unimi.dsi + fastutil + 6.5.9 + + + com.alibaba + fastjson + 1.2.76 + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + ${dep.json.version} + + + * + * + + + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + ${dep.json.version} + + + * + * + + + + + com.fasterxml.jackson.datatype + jackson-datatype-joda + ${dep.json.version} + + + * + * + + + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${dep.json.version} + + + * + * + + + + + com.fasterxml.jackson.module + jackson-module-parameter-names + ${dep.json.version} + + + * + * + + + + + io.hetu.core + presto-spi + ${dep.hetu.version} + + + * + * + + + + + io.hetu.core + hetu-transport + ${dep.hetu.version} + + + * + * + + + + + io.hetu.core + presto-parser + ${dep.hetu.version} + + + * + * + + + + + io.hetu.core + presto-main + ${dep.hetu.version} + + + * + * + + + + + io.hetu.core + presto-expressions + ${dep.hetu.version} + + + com.google.guava + guava + 26.0-jre + + + * + * + + + + + io.airlift + json + ${dep.airlift.version} + + + * + * + + + + + io.airlift + slice + 0.38 + + + cobugsm.google.code.find + jsr305 + + + + + io.airlift + stats + 0.193 + + + cobugsm.google.code.find + jsr305 + + + com.fasterxml.jackson.core + jackson-annotations + + + org.hdrhistogram + HdrHistogram + + + org.weakref + jmxutils + + + + + io.airlift + joni + 2.1.5.3 + + + io.airlift + bytecode + 1.2 + + + * + * + + + + + io.airlift + units + 1.3 + + + * + * + + + + + org.jasypt + jasypt + 1.9.3 + + + org.apache.lucene + lucene-analyzers-common + 7.2.1 + + + * + * + + + + + org.apache.curator + curator-framework + ${dep.curator.version} + + + com.google.guava + guava + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + + + org.apache.curator + curator-recipes + ${dep.curator.version} + + + io.perfmark + perfmark-api + 0.23.0 + + + de.ruedigermoeller + fst + 2.57 + + + org.javassist + javassist + + + org.objenesis + objenesis + + + com.fasterxml.jackson.core + jackson-core + + + + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.4.0 + + src/assembly/assembly.xml + false + boostkit-omnidata-spark-connector-lib + + + + package + + single + + + + + + + + + \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/src/assembly/assembly.xml b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/src/assembly/assembly.xml new file mode 100644 index 0000000000000000000000000000000000000000..ccf4481d59b2baf40914a0b2e5ea149a1626bdd3 --- /dev/null +++ b/omnidata/omnidata-spark-connector/omnidata-spark-connector-lib/src/assembly/assembly.xml @@ -0,0 +1,15 @@ + + bin + + dir + + + + + ./ + true + + + \ No newline at end of file diff 
--git a/omnidata/omnidata-spark-connector/pom.xml b/omnidata/omnidata-spark-connector/pom.xml index 7a5721987ff012d692971ec040e433cf7ebae387..d7c848354d73de34d225ff54348a9ba0199c174d 100644 --- a/omnidata/omnidata-spark-connector/pom.xml +++ b/omnidata/omnidata-spark-connector/pom.xml @@ -7,7 +7,7 @@ org.apache.spark omnidata-spark-connector-root OmniData - Spark Connector Root - 1.4.0 + 1.5.0 pom diff --git a/omnidata/omnidata-spark-connector/spark_build.sh b/omnidata/omnidata-spark-connector/spark_build.sh new file mode 100644 index 0000000000000000000000000000000000000000..12c807d8cd9473766bb942704dd0c6414ba47864 --- /dev/null +++ b/omnidata/omnidata-spark-connector/spark_build.sh @@ -0,0 +1,15 @@ +#!/bin/bash +mvn clean package +jar_name=`ls -n connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}'` +dir_name=`ls -n connector/target/*.jar | grep omnidata-spark | awk -F ' ' '{print$9}' | awk -F '/' '{print$3}' | awk -F '.jar' '{print$1}'` +if [ -d "${dir_name}-aarch64" ];then rm -rf ${dir_name}-aarch64; fi +if [ -d "${dir_name}-aarch64.zip" ];then rm -rf ${dir_name}-aarch64.zip; fi +mkdir -p $dir_name-aarch64 +cp connector/target/$jar_name $dir_name-aarch64 +cd omnidata-spark-connector-lib/ +mvn clean package +cd .. +cd $dir_name-aarch64 +cp ../omnidata-spark-connector-lib/target/boostkit-omnidata-spark-connector-lib/boostkit-omnidata-spark-connector-lib/* . +cd .. +zip -r -o "${dir_name}-aarch64.zip" "${dir_name}-aarch64" \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/pom.xml b/omnidata/omnidata-spark-connector/stub/pom.xml index 283ba45a559d0882ffbff11ec648a277caaab3a7..e6cb368a66eb19bcdc67f05c51225c63d59052a2 100644 --- a/omnidata/omnidata-spark-connector/stub/pom.xml +++ b/omnidata/omnidata-spark-connector/stub/pom.xml @@ -5,18 +5,25 @@ omnidata-spark-connector-root org.apache.spark - 1.4.0 + 1.5.0 4.0.0 com.huawei.boostkit boostkit-omnidata-stub - 1.4.0 + 1.5.0 jar 1.6.1 + 3.1.1 + + org.apache.spark + spark-hive_2.12 + ${spark.version} + compile + com.google.inject guice @@ -45,10 +52,51 @@ + src/main/scala + + org.codehaus.mojo + build-helper-maven-plugin + 3.0.0 + + + generate-sources + + add-source + + + + src/main/java + + + + + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + + scala-compile-first + process-resources + + add-source + compile + + + + compile + + compile + + + + org.apache.maven.plugins maven-compiler-plugin + 3.1 8 8 @@ -81,5 +129,4 @@ - diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/AbstractDecoding.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/AbstractDecoding.java index 043e176cf5a69fe194d1b88c6d418e15a2f97de7..3d43b9b3275f6bda564a6de2270b2ab7f4851373 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/AbstractDecoding.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/AbstractDecoding.java @@ -20,11 +20,23 @@ package com.huawei.boostkit.omnidata.decode; import com.huawei.boostkit.omnidata.decode.type.DecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToByteDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToFloatDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToIntDecodeType; +import com.huawei.boostkit.omnidata.decode.type.LongToShortDecodeType; +import com.huawei.boostkit.omnidata.exception.OmniDataException; import 
io.airlift.slice.SliceInput; +import io.prestosql.spi.type.DateType; +import io.prestosql.spi.type.RowType; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; /** * Abstract decoding @@ -33,29 +45,133 @@ import java.util.Optional; * @since 2021-07-31 */ public abstract class AbstractDecoding implements Decoding { + private static final Map DECODE_METHODS; + + static { + DECODE_METHODS = new HashMap<>(); + Method[] methods = Decoding.class.getDeclaredMethods(); + for (Method method : methods) { + if (method.isAnnotationPresent(Decode.class)) { + DECODE_METHODS.put(method.getAnnotation(Decode.class).value(), method); + } + } + } private Method getDecodeMethod(String decodeName) { - return null; + return DECODE_METHODS.get(decodeName); } private String getDecodeName(SliceInput input) { - return null; + int length = input.readInt(); + byte[] bytes = new byte[length]; + input.readBytes(bytes); + + return new String(bytes, StandardCharsets.UTF_8); } private Optional typeToDecodeName(DecodeType type) { - return null; + Class javaType = null; + if (type.getJavaType().isPresent()) { + javaType = type.getJavaType().get(); + } + if (javaType == double.class) { + return Optional.of("DOUBLE_ARRAY"); + } else if (javaType == float.class) { + return Optional.of("FLOAT_ARRAY"); + } else if (javaType == int.class) { + return Optional.of("INT_ARRAY"); + } else if (javaType == long.class) { + return Optional.of("LONG_ARRAY"); + } else if (javaType == byte.class) { + return Optional.of("BYTE_ARRAY"); + } else if (javaType == boolean.class) { + return Optional.of("BOOLEAN_ARRAY"); + } else if (javaType == short.class) { + return Optional.of("SHORT_ARRAY"); + } else if (javaType == String.class) { + return Optional.of("VARIABLE_WIDTH"); + } else if (javaType == RowType.class) { + return Optional.of("ROW"); + } else if (javaType == DateType.class) { + return Optional.of("DATE"); + } else if (javaType == LongToIntDecodeType.class) { + return Optional.of("LONG_TO_INT"); + } else if (javaType == LongToShortDecodeType.class) { + return Optional.of("LONG_TO_SHORT"); + } else if (javaType == LongToByteDecodeType.class) { + return Optional.of("LONG_TO_BYTE"); + } else if (javaType == LongToFloatDecodeType.class) { + return Optional.of("LONG_TO_FLOAT"); + } else { + return Optional.empty(); + } } @Override public T decode(Optional type, SliceInput sliceInput) { - return null; + try { + String decodeName = getDecodeName(sliceInput); + if (type.isPresent()) { + Optional decodeNameOpt = typeToDecodeName(type.get()); + if ("DECIMAL".equals(decodeNameOpt.orElse(decodeName)) && !"RLE".equals(decodeName)) { + Method method = getDecodeMethod("DECIMAL"); + return (T) method.invoke(this, type, sliceInput, decodeName); + } + if (!"RLE".equals(decodeName)) { + decodeName = decodeNameOpt.orElse(decodeName); + } + } + Method method = getDecodeMethod(decodeName); + return (T) method.invoke(this, type, sliceInput); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new OmniDataException("decode failed " + e.getMessage()); + } } + /** + * decode empty bits. 
+ * + * @param sliceInput input message + * @param positionCount the index of position + * @return corresponding optional object + * */ public Optional decodeNullBits(SliceInput sliceInput, int positionCount) { - return null; + if (!sliceInput.readBoolean()) { + return Optional.empty(); + } + + // read null bits 8 at a time + boolean[] valueIsNull = new boolean[positionCount]; + for (int position = 0; position < (positionCount & ~0b111); position += 8) { + boolean[] nextEightValue = getIsNullValue(sliceInput.readByte()); + int finalPosition = position; + IntStream.range(0, 8).forEach(pos -> valueIsNull[finalPosition + pos] = nextEightValue[pos]); + } + + // read last null bits + if ((positionCount & 0b111) > 0) { + byte value = sliceInput.readByte(); + int maskInt = 0b1000_0000; + for (int pos = positionCount & ~0b111; pos < positionCount; pos++) { + valueIsNull[pos] = ((value & maskInt) != 0); + maskInt >>>= 1; + } + } + + return Optional.of(valueIsNull); } private boolean[] getIsNullValue(byte value) { - return null; + boolean[] isNullValue = new boolean[8]; + isNullValue[0] = ((value & 0b1000_0000) != 0); + isNullValue[1] = ((value & 0b0100_0000) != 0); + isNullValue[2] = ((value & 0b0010_0000) != 0); + isNullValue[3] = ((value & 0b0001_0000) != 0); + isNullValue[4] = ((value & 0b0000_1000) != 0); + isNullValue[5] = ((value & 0b0000_0100) != 0); + isNullValue[6] = ((value & 0b0000_0010) != 0); + isNullValue[7] = ((value & 0b0000_0001) != 0); + + return isNullValue; } -} +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/ArrayDecodeType.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/ArrayDecodeType.java index ca2f9942a0429325719f9de5c148a622e2fec4f8..c03a92380fefc0c72255ac554949f28af5f8f9fb 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/ArrayDecodeType.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/ArrayDecodeType.java @@ -29,9 +29,18 @@ import java.util.Optional; * @since 2021-07-31 */ public class ArrayDecodeType implements DecodeType { + private final T elementType; + + public ArrayDecodeType(T elementType) { + this.elementType = elementType; + } + + public T getElementType() { + return elementType; + } + @Override public Optional> getJavaType() { return Optional.empty(); } -} - +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/IntDecodeType.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/IntDecodeType.java index 49331b421d92470bb484f07b875caa2e34aeb5ac..763b295d30809a7accd04287b1255d8f9b608891 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/IntDecodeType.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/IntDecodeType.java @@ -30,7 +30,6 @@ import java.util.Optional; public class IntDecodeType implements DecodeType { @Override public Optional> getJavaType() { - return Optional.empty(); + return Optional.of(int.class); } -} - +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/MapDecodeType.java 
b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/MapDecodeType.java index 651e4e776fcc0fa771fb499209499f845747d4b7..f3a5351c4f8488bc329e678b265464041ea4eca1 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/MapDecodeType.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/decode/type/MapDecodeType.java @@ -30,9 +30,16 @@ import java.util.Optional; * @since 2021-07-31 */ public class MapDecodeType implements DecodeType { + private final K keyType; + private final V valueType; + + public MapDecodeType(K keyType, V valueType) { + this.keyType = keyType; + this.valueType = valueType; + } + @Override public Optional> getJavaType() { return Optional.empty(); } -} - +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java index c3da4708db16a7f5830dac5f7f1f1dc1cf876df7..77915733320a2a136a7faa515edc7830df575923 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/exception/OmniDataException.java @@ -23,9 +23,14 @@ import static com.huawei.boostkit.omnidata.exception.OmniErrorCode.OMNIDATA_GENE public class OmniDataException extends RuntimeException { public OmniDataException(String message) { + super(message); } public OmniErrorCode getErrorCode() { return OMNIDATA_GENERIC_ERROR; } -} + @Override + public String getMessage() { + return super.getMessage(); + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/model/TaskSource.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/model/TaskSource.java index 74c5a40307af54caf7c7111f5bbad9c545239442..6f2c022ad5e0e7c597107f4911b5568ddac80597 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/model/TaskSource.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/model/TaskSource.java @@ -23,5 +23,7 @@ import com.huawei.boostkit.omnidata.model.datasource.DataSource; public class TaskSource { public TaskSource(DataSource dataSource, Predicate predicate, int maxPageSizeInBytes) {} + + public TaskSource(DataSource dataSource, Predicate predicate, int maxPageSizeInBytes, String groupId) {} } diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/reader/impl/DataReaderImpl.java b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/reader/impl/DataReaderImpl.java index c9464a05234204f1fe2313d6dae26184ecf9b866..167a5e7b856dce566a8d8f2d320b1c922cdfeb68 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/reader/impl/DataReaderImpl.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/reader/impl/DataReaderImpl.java @@ -39,5 +39,7 @@ public class DataReaderImpl { } public void close() {} + + public void forceClose(String groupId) {} } diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java 
b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java index a1baaad829c74eff6b063ff92e27d09e1055c5a9..2b7f8c7debb1d1eea9a499a150ef4c0be0e240c2 100644 --- a/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java +++ b/omnidata/omnidata-spark-connector/stub/src/main/java/com/huawei/boostkit/omnidata/serialize/OmniDataBlockEncodingSerde.java @@ -19,25 +19,81 @@ package com.huawei.boostkit.omnidata.serialize; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableMap; + import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; import io.prestosql.spi.block.*; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Optional; + /** * Block Encoding Serde * * @since 2021-07-31 */ public final class OmniDataBlockEncodingSerde implements BlockEncodingSerde { + private final Map blockEncodings; + + public OmniDataBlockEncodingSerde() { + blockEncodings = + ImmutableMap.builder() + .put(VariableWidthBlockEncoding.NAME, new VariableWidthBlockEncoding()) + .put(ByteArrayBlockEncoding.NAME, new ByteArrayBlockEncoding()) + .put(ShortArrayBlockEncoding.NAME, new ShortArrayBlockEncoding()) + .put(IntArrayBlockEncoding.NAME, new IntArrayBlockEncoding()) + .put(LongArrayBlockEncoding.NAME, new LongArrayBlockEncoding()) + .put(Int128ArrayBlockEncoding.NAME, new Int128ArrayBlockEncoding()) + .put(DictionaryBlockEncoding.NAME, new DictionaryBlockEncoding()) + .put(ArrayBlockEncoding.NAME, new ArrayBlockEncoding()) + .put(RowBlockEncoding.NAME, new RowBlockEncoding()) + .put(SingleRowBlockEncoding.NAME, new SingleRowBlockEncoding()) + .put(RunLengthBlockEncoding.NAME, new RunLengthBlockEncoding()) + .put(LazyBlockEncoding.NAME, new LazyBlockEncoding()) + .build(); + } + + private static String readLengthPrefixedString(SliceInput sliceInput) { + int length = sliceInput.readInt(); + byte[] bytes = new byte[length]; + sliceInput.readBytes(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + private static void writeLengthPrefixedString(SliceOutput sliceOutput, String value) { + byte[] bytes = value.getBytes(UTF_8); + sliceOutput.writeInt(bytes.length); + sliceOutput.writeBytes(bytes); + } @Override - public Block readBlock(SliceInput input) { - return null; + public Block readBlock(SliceInput input) { + return blockEncodings.get(readLengthPrefixedString(input)).readBlock(this, input); } @Override public void writeBlock(SliceOutput output, Block block) { + Block readBlock = block; + while (true) { + String encodingName = readBlock.getEncodingName(); - } -} + BlockEncoding blockEncoding = blockEncodings.get(encodingName); + + Optional replacementBlock = blockEncoding.replacementBlockForWrite(readBlock); + if (replacementBlock.isPresent()) { + readBlock = replacementBlock.get(); + continue; + } + writeLengthPrefixedString(output, encodingName); + + blockEncoding.writeBlock(this, output, readBlock); + + break; + } + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnidata/omnidata-spark-connector/stub/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java new file mode 100644 index 0000000000000000000000000000000000000000..f2c3ab769ed1f36b63752776789ab03d4ddce6c7 --- /dev/null +++ 
b/omnidata/omnidata-spark-connector/stub/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java @@ -0,0 +1,279 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * OmniColumnVector stub code + * + * @since 2023-04-04 + */ +public class OmniColumnVector extends WritableColumnVector { + public OmniColumnVector(int capacity, DataType type, boolean isInitVec) { + super(capacity, type); + } + + @Override + public int getDictId(int rowId) { + return 0; + } + + @Override + protected void reserveInternal(int capacity) { + + } + + @Override + public void putNotNull(int rowId) { + + } + + @Override + public void putNull(int rowId) { + + } + + @Override + public void putNulls(int rowId, int count) { + + } + + @Override + public void putNotNulls(int rowId, int count) { + + } + + @Override + public void putBoolean(int rowId, boolean isValue) { + + } + + @Override + public void putBooleans(int rowId, int count, boolean isValue) { + + } + + @Override + public void putByte(int rowId, byte value) { + + } + + @Override + public void putBytes(int rowId, int count, byte value) { + + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putShort(int rowId, short value) { + + } + + @Override + public void putShorts(int rowId, int count, short value) { + + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + + } + + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putInt(int rowId, int value) { + + } + + @Override + public void putInts(int rowId, int count, int value) { + + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + + } + + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putLong(int rowId, long value) { + + } + + @Override + public void putLongs(int rowId, int count, long value) { + + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + + } + + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putFloat(int rowId, float value) { + + } + + @Override + public void putFloats(int rowId, int count, float value) { 
+ + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putFloatsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putDouble(int rowId, double value) { + + } + + @Override + public void putDoubles(int rowId, int count, double value) { + + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putDoublesLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + + } + + @Override + public void putArray(int rowId, int offset, int length) { + + } + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int count) { + return 0; + } + + @Override + protected UTF8String getBytesAsUTF8String(int rowId, int count) { + return null; + } + + @Override + public int getArrayLength(int rowId) { + return 0; + } + + @Override + public int getArrayOffset(int rowId) { + return 0; + } + + @Override + protected WritableColumnVector reserveNewColumn(int capacity, DataType type) { + return null; + } + + @Override + public boolean isNullAt(int rowId) { + return false; + } + + @Override + public boolean getBoolean(int rowId) { + return false; + } + + @Override + public byte getByte(int rowId) { + return 0; + } + + @Override + public short getShort(int rowId) { + return 0; + } + + @Override + public int getInt(int rowId) { + return 0; + } + + @Override + public long getLong(int rowId) { + return 0; + } + + @Override + public float getFloat(int rowId) { + return 0; + } + + @Override + public double getDouble(int rowId) { + return 0; + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarConditionProjectExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarConditionProjectExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..aeed523e80ff0c28cb0457bb843b5b62ea8543f6 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarConditionProjectExec.scala @@ -0,0 +1,20 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} + +case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], + condition: Expression, + child: SparkPlan) + extends UnaryExecNode + with AliasAwareOutputPartitioning + with AliasAwareOutputOrdering { + override protected def orderingExpressions: Seq[SortOrder] = ??? + + override protected def outputExpressions: Seq[NamedExpression] = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? + + override def output: Seq[Attribute] = ??? 
+} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarFilterExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarFilterExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..7ea123b3a207a8e5c5884c2375922c49bbbef074 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarFilterExec.scala @@ -0,0 +1,13 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PredicateHelper} + +case class ColumnarFilterExec(condition: Expression, child: SparkPlan) + extends UnaryExecNode with PredicateHelper { + override protected def doExecute(): RDD[InternalRow] = ??? + + override def output: Seq[Attribute] = ??? + +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..c65615b806eb5ef24d25765371e0d4471f478562 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -0,0 +1,21 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec + +case class ColumnarHashAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends BaseAggregateExec + with AliasAwareOutputPartitioning { + override protected def doExecute(): RDD[InternalRow] = ??? + +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarProjectExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarProjectExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..8d4f53380b29e2a27d0b3ac74bbb804d95919f8f --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarProjectExec.scala @@ -0,0 +1,18 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} + +case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryExecNode + with AliasAwareOutputPartitioning + with AliasAwareOutputOrdering { + override protected def orderingExpressions: Seq[SortOrder] = ??? + + override protected def outputExpressions: Seq[NamedExpression] = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? + + override def output: Seq[Attribute] = ??? 
+} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..b1a1614c4c2fbd934a5f68961a232db72c834e9b --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -0,0 +1,28 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.MapOutputStatistics +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin} + +import scala.concurrent.Future + +case class ColumnarShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeLike { + override def numMappers: Int = ??? + + override def numPartitions: Int = ??? + + override def mapOutputStatisticsFuture: Future[MapOutputStatistics] = ??? + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = ??? + + override def runtimeStatistics: Statistics = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..44f99015685e56766faae1f267fde2d6dd5baa84 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -0,0 +1,16 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} + +case class ColumnarSortExec(sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryExecNode{ + override protected def doExecute(): RDD[InternalRow] = ??? + + override def output: Seq[Attribute] = ??? + +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/OmniColumnarToRowExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/OmniColumnarToRowExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..c803e59ada9d42eaf163394114b89ae44f8d330e --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/OmniColumnarToRowExec.scala @@ -0,0 +1,10 @@ +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition { + override protected def doExecute(): RDD[InternalRow] = ??? + override def output: Seq[Attribute] = ??? 
+} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/RowToOmniColumnarExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/RowToOmniColumnarExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..c80bb21c9ec8c8b3da74d8ce6a4cdeae4a987de4 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/RowToOmniColumnarExec.scala @@ -0,0 +1,10 @@ +package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransition { + override protected def doExecute(): RDD[InternalRow] = ??? + + override def output: Seq[Attribute] = ??? +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarBroadcastHashJoinExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarBroadcastHashJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..6608f939a6d936e17b4caeb96cad0d5067bb75c2 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarBroadcastHashJoinExec.scala @@ -0,0 +1,28 @@ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.SparkPlan + +case class ColumnarBroadcastHashJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends HashJoin { + override protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo = ??? + + override def inputRDDs(): Seq[RDD[InternalRow]] = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? 
+ +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarShuffledHashJoinExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarShuffledHashJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..5602acab98807c244442f0a1ec086e1a23c2543a --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarShuffledHashJoinExec.scala @@ -0,0 +1,28 @@ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.SparkPlan + +case class ColumnarShuffledHashJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + projectList: Seq[NamedExpression] = Seq.empty) + extends HashJoin with ShuffledJoin { + override protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo = ??? + + + override def inputRDDs(): Seq[RDD[InternalRow]] = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? + +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarSortMergeJoinExec.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarSortMergeJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..fe932fa6ed1427ce687618555621a945e29944fe --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/join/ColumnarSortMergeJoinExec.scala @@ -0,0 +1,27 @@ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} + +case class ColumnarSortMergeJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isSkewJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends ShuffledJoin with CodegenSupport { + override def inputRDDs(): Seq[RDD[InternalRow]] = ??? + + override protected def doProduce(ctx: CodegenContext): String = ??? + + override protected def doExecute(): RDD[InternalRow] = ??? + +} + diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..8218b65d046cfeb9844c57458bf34e371f0d174f --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala @@ -0,0 +1,24 @@ +/* + * Copyright (C) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.util + +object SparkMemoryUtils { + def init(): Unit = {} +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..15ba1a3a371fc44f10fca29b250131a9fe436079 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +object ReflectionUtils { + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + null + } + +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000000000000000000000000000000..8a59530f39ebb1811a0803494c4eaab25d56c0b9 --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.Driver +import org.apache.spark.sql.SQLContext + +class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) + extends Driver { +} diff --git a/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000000000000000000000000000000..3f58112a12cfb8058dde3f81055c3dfeae9c764c --- /dev/null +++ b/omnidata/omnidata-spark-connector/stub/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext + +object SparkSQLEnv { + var sqlContext: SQLContext = _ + var sparkContext: SparkContext = _ + + def init(): Unit = {} + + def stop(): Unit = {} +} diff --git a/omnicache/omnicache-spark-extension/README.md b/omnimv/omnimv-spark-extension/README.md similarity index 79% rename from omnicache/omnicache-spark-extension/README.md rename to omnimv/omnimv-spark-extension/README.md index 2351940168b46db854fff0f931347bc9f058b3d3..c4e24f1d0ee3cc2c4dd7580f81d1770c5a53a65f 100644 --- a/omnicache/omnicache-spark-extension/README.md +++ b/omnimv/omnimv-spark-extension/README.md @@ -1,4 +1,4 @@ -# omnicache-spark-extension +# omnimv-spark-extension A SQL Engine Extension for Spark SQL to support Materialized View @@ -11,7 +11,7 @@ conditions. The Spark plugin is used to add materialized view management and execution plan rewriting capabilities, greatly improving Spark computing efficiency. 
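For a feel of the materialized view management this plugin exposes, the sketch below shows statement shapes accepted by the OmniMVSqlExtensions grammar added later in this patch (the `ALTER ... ENABLE/DISABLE REWRITE` rule and the new `WASH OUT` rule). The `spark-sql` invocation and view names are illustrative only, and the sketch assumes the plugin's SQL extension is already registered with the Spark session.

```shell
# Illustrative only: view names are placeholders; statement shapes follow the
# alterRewriteMV and washOutMV grammar rules introduced in this patch.
spark-sql -e "ALTER MATERIALIZED VIEW db.mv_sales ENABLE REWRITE"
spark-sql -e "WASH OUT MATERIALIZED VIEW USING UNUSED_DAYS 30, RESERVE_QUANTITY_BY_VIEW_COUNT 10"
spark-sql -e "WASH OUT ALL MATERIALIZED VIEW"
```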
-## Environment for building OmniCache +## Environment for building OmniMV ```shell # download @@ -22,13 +22,13 @@ tar -zxvf hadoop-3.1.1.tar.gz export HADOOP_HOME="${pwd}/haddoop-3.1.1" ``` -## Build OmniCache +## Build OmniMV -pull the OmniCache code and compile it to get the jar package +pull the OmniMV code and compile it to get the jar package ```shell git clone https://gitee.com/kunpengcompute/boostkit-bigdata.git -cd boostkit-bigdata/omnicache/omnicache-spark-extension -# This step can be compiled, tested and packaged to get plugin/boostkit-omnicache-spark-3.1.1-1.0.0.jar +cd boostkit-bigdata/omnimv/omnimv-spark-extension +# This step can be compiled, tested and packaged to get plugin/boostkit-omnimv-spark-${omniMV.version}.jar mvn clean package ``` diff --git a/omnicache/omnicache-spark-extension/build.sh b/omnimv/omnimv-spark-extension/build.sh similarity index 100% rename from omnicache/omnicache-spark-extension/build.sh rename to omnimv/omnimv-spark-extension/build.sh diff --git a/omnicache/omnicache-spark-extension/dev/checkstyle.xml b/omnimv/omnimv-spark-extension/dev/checkstyle.xml similarity index 100% rename from omnicache/omnicache-spark-extension/dev/checkstyle.xml rename to omnimv/omnimv-spark-extension/dev/checkstyle.xml diff --git a/omnicache/omnicache-spark-extension/log-parser/pom.xml b/omnimv/omnimv-spark-extension/log-parser/pom.xml similarity index 87% rename from omnicache/omnicache-spark-extension/log-parser/pom.xml rename to omnimv/omnimv-spark-extension/log-parser/pom.xml index 75f93ca239dde0b92a2ea43e91d56e56444b9560..57b18ac91187ec72013f8fa4baf0d10ddbc70ab0 100644 --- a/omnicache/omnicache-spark-extension/log-parser/pom.xml +++ b/omnimv/omnimv-spark-extension/log-parser/pom.xml @@ -5,31 +5,22 @@ com.huawei.kunpeng - boostkit-omnicache-spark-parent - 3.1.1-1.0.0 + boostkit-omnimv-spark-parent + ${omnimv.version} 4.0.0 - boostkit-omnicache-logparser-spark + boostkit-omnimv-logparser-spark jar - 3.1.1-1.0.0 + ${omnimv.version} log-parser - - 14.0.1 - - com.huawei.kunpeng - boostkit-omnicache-spark - 3.1.1-1.0.0 - - - com.google.guava - guava - ${guava.version} + boostkit-omnimv-spark + ${omnimv.version} org.apache.spark @@ -45,22 +36,6 @@ - - org.apache.spark - spark-core_${scala.binary.version} - test-jar - test - - - org.apache.hadoop - hadoop-client - - - org.apache.curator - curator-recipes - - - junit junit @@ -232,7 +207,7 @@ true true ${project.build.sourceEncoding} - true + ${scoverage.skip} diff --git a/omnicache/omnicache-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala b/omnimv/omnimv-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala similarity index 85% rename from omnicache/omnicache-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala rename to omnimv/omnimv-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala index 2b447335fec2b434fbe1e6e72ab311eb5bc36b8d..c71f130b58cd52009b503f69d353c4c9397eea45 100644 --- a/omnicache/omnicache-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala +++ b/omnimv/omnimv-spark-extension/log-parser/src/main/scala/org/apache/spark/deploy/history/LogsParser.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.history -import com.huawei.boostkit.spark.util.RewriteLogger +import com.huawei.boostkit.spark.util.{KerberosUtil, RewriteLogger} import java.io.FileNotFoundException import java.text.SimpleDateFormat import 
java.util.ServiceLoader import java.util.regex.Pattern +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.json4s.DefaultFormats import org.json4s.jackson.Json @@ -43,10 +44,15 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extends RewriteLogger { private val LINE_SEPARATOR = "\n" - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val hadoopConf = confLoad() // Visible for testing private[history] val fs: FileSystem = new Path(eventLogDir).getFileSystem(hadoopConf) + def confLoad(): Configuration = { + val configuration: Configuration = SparkHadoopUtil.get.newConfiguration(conf) + KerberosUtil.newConfiguration(configuration) + } + /** * parseAppHistoryLog * @@ -94,12 +100,17 @@ class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extend } } } + val duration = if (uiData.completionTime.isDefined) { + (uiData.completionTime.get.getTime - uiData.submissionTime) + "ms" + } else { + "Unfinished" + } // write dot val graph: SparkPlanGraph = sqlStatusStore.planGraph(executionId) sqlStatusStore.planGraph(executionId) val metrics = sqlStatusStore.executionMetrics(executionId) - val node = getNodeInfo(graph) + val node = getNodeInfo(graph, metrics) val jsonMap = Map( "logName" -> appId, @@ -108,7 +119,8 @@ class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extend "materialized views" -> mvs, "physical plan" -> planDesc, "dot metrics" -> graph.makeDotFile(metrics), - "node metrics" -> node) + "node metrics" -> node, + "duration" -> duration) jsons :+= jsonMap } } @@ -116,7 +128,7 @@ class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extend case e: FileNotFoundException => throw e case e: Throwable => - logWarning(s"Failed to parseAppHistoryLog ${appId} for ${e.getMessage}") + logWarning(s"Failed to parseAppHistoryLog $appId for ${e.getMessage}") } jsons } @@ -163,29 +175,25 @@ class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extend * @param graph SparkPlanGraph * @return NodeInfo */ - def getNodeInfo(graph: SparkPlanGraph): String = { + def getNodeInfo(graph: SparkPlanGraph, metricsValue: Map[Long, String]): String = { // write node - val tmpContext = new StringBuilder + val tmpContext = new mutable.StringBuilder tmpContext.append("[PlanMetric]") nextLine(tmpContext) graph.allNodes.foreach { node => tmpContext.append(s"id:${node.id} name:${node.name} desc:${node.desc}") nextLine(tmpContext) node.metrics.foreach { metric => - metric.toString - tmpContext.append("SQLPlanMetric(") - tmpContext.append(metric.name) - tmpContext.append(",") - if (metric.metricType == "timing") { - tmpContext.append(s"${metric.accumulatorId * 1000000} ns, ") - } else if (metric.metricType == "nsTiming") { - tmpContext.append(s"${metric.accumulatorId} ns, ") - } else { - tmpContext.append(s"${metric.accumulatorId}, ") + val value = metricsValue.get(metric.accumulatorId) + if (value.isDefined) { + tmpContext.append("SQLPlanMetric(") + .append(metric.name) + .append(",") + .append(getMetrics(value.get)).append(", ") + .append(metric.metricType) + .append(")") + nextLine(tmpContext) } - tmpContext.append(metric.metricType) - tmpContext.append(")") - nextLine(tmpContext) } nextLine(tmpContext) nextLine(tmpContext) @@ -206,13 +214,41 @@ class LogsParser(conf: SparkConf, eventLogDir: String, outPutDir: String) extend tmpContext.append(s"${cluster.nodes(i).id} 
") } nextLine(tmpContext) - case node => + case _ => } nextLine(tmpContext) tmpContext.toString() } - def nextLine(context: StringBuilder): Unit = { + def getMetrics(context: String): String = { + val separator = '\n' + val detail = s"total (min, med, max (stageId: taskId))$separator" + if (!context.contains(detail)) { + return context + } + // get metrics like 'total (min, med, max (stageId: taskId))'. + val lines = context.split(separator) + val res = lines.map(_.replaceAll(" \\(", ", ") + .replace("stage ", "") + .replace("task ", "") + .replace(")", "") + .replace(":", ", ") + .split(",")).reduce((t1, t2) => { + val sb = new StringBuilder + sb.append('[') + for (i <- 0 until (t1.size)) { + sb.append(t1(i)).append(":").append(t2(i)) + if (i != t1.size - 1) { + sb.append(", ") + } + } + sb.append(']') + Array(sb.toString()) + }) + res(0) + } + + def nextLine(context: mutable.StringBuilder): Unit = { context.append(LINE_SEPARATOR) } @@ -325,6 +361,8 @@ arg1: output dir in hdfs, eg. hdfs://server1:9000/logParser arg2: log file to be parsed, eg. application_1646816941391_0115.lz4 */ object ParseLog extends RewriteLogger { + val regex = ".*application_[0-9]+_[0-9]+.*(\\.lz4)?$" + def main(args: Array[String]): Unit = { if (args == null || args.length != 3) { throw new RuntimeException("input params is invalid,such as below\n" + @@ -346,7 +384,6 @@ object ParseLog extends RewriteLogger { val logParser = new LogsParser(conf, sparkEventLogDir, outputDir) // file pattern - val regex = "^application_[0-9]+._[0-9]+.lz4$" val pattern = Pattern.compile(regex) val matcher = pattern.matcher(logName) if (matcher.find) { @@ -391,7 +428,7 @@ object ParseLogs extends RewriteLogger { val logParser = new LogsParser(conf, sparkEventLogDir, outputDir) // file pattern - val regex = "^application_[0-9]+._[0-9]+.lz4$" + val regex = ParseLog.regex val pattern = Pattern.compile(regex) val dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm") val appIds = logParser.listAppHistoryLogs(pattern, diff --git a/omnicache/omnicache-spark-extension/log-parser/src/test/resources/application_1663257594501_0003.lz4 b/omnimv/omnimv-spark-extension/log-parser/src/test/resources/application_1663257594501_0003.lz4 similarity index 100% rename from omnicache/omnicache-spark-extension/log-parser/src/test/resources/application_1663257594501_0003.lz4 rename to omnimv/omnimv-spark-extension/log-parser/src/test/resources/application_1663257594501_0003.lz4 diff --git a/omnicache/omnicache-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala b/omnimv/omnimv-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala similarity index 94% rename from omnicache/omnicache-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala rename to omnimv/omnimv-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala index b9c60d0566c99851e0e5b96bb1f6a432025e031e..edc506761806ba89bb715ad91947bd4074c9adff 100644 --- a/omnicache/omnicache-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala +++ b/omnimv/omnimv-spark-extension/log-parser/src/test/scala/org/apache/spark/deploy/history/LogsParserSuite.scala @@ -25,10 +25,12 @@ import org.apache.commons.io.IOUtils import org.apache.commons.lang3.time.DateUtils import org.json4s.DefaultFormats import org.json4s.jackson.Json +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import 
org.scalatest.funsuite.AnyFunSuite -import org.apache.spark.SparkFunSuite - -class LogsParserSuite extends SparkFunSuite { +class LogsParserSuite extends AnyFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach { test("parse") { val path = this.getClass.getResource("/").getPath diff --git a/omnicache/omnicache-spark-extension/plugin/pom.xml b/omnimv/omnimv-spark-extension/plugin/pom.xml similarity index 91% rename from omnicache/omnicache-spark-extension/plugin/pom.xml rename to omnimv/omnimv-spark-extension/plugin/pom.xml index 721879bddbf536cc7e37ca8d63dbd4cec3f93b0e..349ad5c6fdde350625256ff12f31d4c903d92c7d 100644 --- a/omnicache/omnicache-spark-extension/plugin/pom.xml +++ b/omnimv/omnimv-spark-extension/plugin/pom.xml @@ -5,14 +5,14 @@ com.huawei.kunpeng - boostkit-omnicache-spark-parent - 3.1.1-1.0.0 + boostkit-omnimv-spark-parent + ${omnimv.version} 4.0.0 - boostkit-omnicache-spark + boostkit-omnimv-spark jar - 3.1.1-1.0.0 + ${omnimv.version} plugin @@ -21,6 +21,10 @@ + + com.esotericsoftware + kryo-shaded + com.google.guava guava @@ -40,22 +44,6 @@ - - org.apache.spark - spark-core_${scala.binary.version} - test-jar - test - - - org.apache.hadoop - hadoop-client - - - org.apache.curator - curator-recipes - - - junit junit @@ -227,7 +215,7 @@ true true ${project.build.sourceEncoding} - true + ${scoverage.skip} diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniCacheSqlExtensions.g4 b/omnimv/omnimv-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniMVSqlExtensions.g4 similarity index 98% rename from omnicache/omnicache-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniCacheSqlExtensions.g4 rename to omnimv/omnimv-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniMVSqlExtensions.g4 index ea797d36f6d7e1a4cd35a67d277ecd3c4e80e066..6aca30b498bf255d700bebb05611f907c29c5986 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniCacheSqlExtensions.g4 +++ b/omnimv/omnimv-spark-extension/plugin/src/main/antlr4/org/apache/spark/sql/catalyst/parser/OmniMVSqlExtensions.g4 @@ -15,7 +15,7 @@ * This file is an adaptation of Spark's spark/sql/catalyst/src/main/antlr4/org/apache/spark/sql/parser/SqlBase.g4 grammar. */ -grammar OmniCacheSqlExtensions; +grammar OmniMVSqlExtensions; @parser::members { /** @@ -131,8 +131,22 @@ statement statement #explain | ALTER MATERIALIZED VIEW multipartIdentifier (ENABLE|DISABLE) REWRITE #alterRewriteMV + | WASHOUT (ALL)? MATERIALIZED VIEW (washOutExpressions)? #washOutMV ; +washOutExpressions + : USING washOutStrategy (',' washOutStrategy)* + ; + +washOutStrategy + : UNUSED_DAYS (washOutValue)? + | RESERVE_QUANTITY_BY_VIEW_COUNT (washOutValue)? + | DROP_QUANTITY_BY_SPACE_CONSUMED (washOutValue)? + ; + +washOutValue + : INTEGER_VALUE + ; createMVHeader : CREATE MATERIALIZED VIEW (IF NOT EXISTS)? 
multipartIdentifier @@ -1370,6 +1384,10 @@ ADD: 'ADD'; AFTER: 'AFTER'; ALL: 'ALL'; ALTER: 'ALTER'; +WASHOUT: 'WASH OUT'; +UNUSED_DAYS: 'UNUSED_DAYS'; +RESERVE_QUANTITY_BY_VIEW_COUNT: 'RESERVE_QUANTITY_BY_VIEW_COUNT'; +DROP_QUANTITY_BY_SPACE_CONSUMED: 'DROP_QUANTITY_BY_SPACE_CONSUMED'; ANALYZE: 'ANALYZE'; AND: 'AND'; ANTI: 'ANTI'; diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/AbstractImmutableList.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/AbstractImmutableList.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/AbstractImmutableList.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/AbstractImmutableList.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/ConsList.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/ConsList.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/ConsList.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/runtime/ConsList.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/Pair.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/Pair.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/Pair.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/Pair.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java similarity index 97% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java index 0cca0f630476eb5632df4db71c828abe8979e8f9..898401d64c00051635f22e3c327f8a2819708ed6 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java +++ b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/RangeUtil.java @@ -34,7 +34,12 @@ import java.util.Set; import static org.apache.spark.sql.types.DataTypes.NullType; import static org.apache.spark.sql.types.DataTypes.BooleanType; -public class RangeUtil { +public final class RangeUtil { + + private RangeUtil() { + throw new IllegalStateException("Utility class"); + } + public static Expression simplifyUsingPredicates(Expression expr, Set pulledUpPredicates) { Option opt = ExprOptUtil.createComparison(expr); if (opt.isEmpty() || opt.get().literal().value() == null) { @@ -180,7 +185,7 @@ public class RangeUtil { Pair.of(Range.singleton(v0), ImmutableList.of(predicate))); // remove for (Expression e : p.right) { - replaceAllExpression(terms, e, Literal.TrueLiteral()); + replaceExpression(terms, e, Literal.TrueLiteral()); } break; } @@ -337,7 +342,7 @@ public class RangeUtil { ImmutableList.Builder newBounds = ImmutableList.builder(); for (Expression e : p.right) { if (ExprOptUtil.isUpperBound(e)) { - replaceAllExpression(terms, e, Literal.TrueLiteral()); + replaceExpression(terms, e, Literal.TrueLiteral()); } else { newBounds.add(e); } @@ -348,7 +353,7 @@ public class RangeUtil { ImmutableList.Builder 
newBounds = ImmutableList.builder(); for (Expression e : p.right) { if (ExprOptUtil.isLowerBound(e)) { - replaceAllExpression(terms, e, Literal.TrueLiteral()); + replaceExpression(terms, e, Literal.TrueLiteral()); } else { newBounds.add(e); } @@ -361,10 +366,10 @@ public class RangeUtil { return null; } - private static boolean replaceAllExpression(List terms, Expression oldVal, Expression newVal) { + private static boolean replaceExpression(List terms, Expression oldVal, Expression newVal) { boolean result = false; for (int i = 0; i < terms.size(); i++) { - if (terms.get(i).equals(oldVal)) { + if (terms.get(i) == oldVal) { terms.set(i, newVal); result = true; } diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultDirectedGraph.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultDirectedGraph.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultDirectedGraph.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultDirectedGraph.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultEdge.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultEdge.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultEdge.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DefaultEdge.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DirectedGraph.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DirectedGraph.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DirectedGraph.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/DirectedGraph.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/Graphs.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/Graphs.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/Graphs.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/Graphs.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/TopologicalOrderIterator.java b/omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/TopologicalOrderIterator.java similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/TopologicalOrderIterator.java rename to omnimv/omnimv-spark-extension/plugin/src/main/java/org/apache/calcite/util/graph/TopologicalOrderIterator.java diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniCache.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniMV.scala similarity index 76% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniCache.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniMV.scala index 
b6feb1ab8281d2beb502f1b0a75a81397d3fa26d..aab01b4b0252b1e12e4a2c69a8701f4c2c780d49 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniCache.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/OmniMV.scala @@ -20,26 +20,26 @@ package com.huawei.boostkit.spark import com.huawei.boostkit.spark.util.RewriteLogger import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.optimizer.OmniCacheOptimizer -import org.apache.spark.sql.catalyst.parser.OmniCacheExtensionSqlParser +import org.apache.spark.sql.catalyst.optimizer.OmniMVOptimizer +import org.apache.spark.sql.catalyst.parser.OmniMVExtensionSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -class OmniCache extends (SparkSessionExtensions => Unit) with RewriteLogger { +class OmniMV extends (SparkSessionExtensions => Unit) with RewriteLogger { override def apply(extensions: SparkSessionExtensions): Unit = { - // OmniCache internal parser + // OmniMV internal parser extensions.injectParser { case (spark, parser) => - new OmniCacheExtensionSqlParser(spark, parser) + new OmniMVExtensionSqlParser(spark, parser) } - // OmniCache optimizer rules + // OmniMV optimizer rules extensions.injectPostHocResolutionRule { (session: SparkSession) => - OmniCacheOptimizerRule(session) + OmniMVOptimizerRule(session) } } } -case class OmniCacheOptimizerRule(session: SparkSession) extends Rule[LogicalPlan] { +case class OmniMVOptimizerRule(session: SparkSession) extends Rule[LogicalPlan] { self => var notAdded = true @@ -53,7 +53,7 @@ case class OmniCacheOptimizerRule(session: SparkSession) extends Rule[LogicalPla val field = sessionState.getClass.getDeclaredField("optimizer") field.setAccessible(true) field.set(sessionState, - OmniCacheOptimizer(session, sessionState.optimizer)) + OmniMVOptimizer(session, sessionState.optimizer)) } } } diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniCachePluginConfig.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniMVPluginConfig.scala similarity index 37% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniCachePluginConfig.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniMVPluginConfig.scala index 77cacf0670244d3b0293796cb7b963cd266882b1..ad271e199ce9a5d71ff43952f1d43b4743684a95 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniCachePluginConfig.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/conf/OmniMVPluginConfig.scala @@ -26,67 +26,143 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.SQLConf -class OmniCachePluginConfig(conf: SQLConf) { +class OmniMVPluginConfig(conf: SQLConf) { - // enable or disable OmniCache - def enableOmniCache: Boolean = conf - .getConfString("spark.sql.omnicache.enable", "true").toBoolean + // enable or disable OmniMV + def enableOmniMV: Boolean = conf + .getConfString("spark.sql.omnimv.enable", "true").toBoolean // show mv querySql length def showMVQuerySqlLen: Int = conf - .getConfString("spark.sql.omnicache.show.length", "50").toInt + 
.getConfString("spark.sql.omnimv.show.length", "50").toInt - // database where create OmniCache - val omniCacheDB: String = conf - .getConfString("spark.sql.omnicache.db", "default") + // database where create OmniMV, like omnimv,omnimv1 + def omniMVDB: String = conf + .getConfString("spark.sql.omnimv.dbs", "") // rewrite cur match mv def curMatchMV: String = conf - .getConfString("spark.sql.omnicache.cur.match.mv", "") + .getConfString("spark.sql.omnimv.cur.match.mv", "") def setCurMatchMV(mv: String): Unit = { - conf.setConfString("spark.sql.omnicache.cur.match.mv", mv) + conf.setConfString("spark.sql.omnimv.cur.match.mv", mv) } - val defaultDataSource: String = conf - .getConfString("spark.sql.omnicache.default.datasource", "orc") + // mv table datasource + def defaultDataSource: String = conf + .getConfString("spark.sql.omnimv.default.datasource", "orc") val dataSourceSet: Set[String] = Set("orc", "parquet") + // omnimv loglevel def logLevel: String = conf - .getConfString("spark.sql.omnicache.logLevel", "DEBUG") + .getConfString("spark.sql.omnimv.logLevel", "DEBUG") .toUpperCase(Locale.ROOT) + + // set parsed sql as JobDescription + def enableSqlLog: Boolean = conf + .getConfString("spark.sql.omnimv.log.enable", "true") + .toBoolean + + // omnimv metadata path + def metadataPath: String = conf + .getConfString("spark.sql.omnimv.metadata.path", "/user/omnimv/metadata") + + // enable omnimv init by query + def enableMetadataInitByQuery: Boolean = conf + .getConfString("spark.sql.omnimv.metadata.initbyquery.enable", "false") + .toBoolean + + // metadata index tail lines + def metadataIndexTailLines: Long = conf + .getConfString("spark.sql.omnimv.metadata.index.tail.lines", "5") + .toLong + + // Minimum unused time required for wash out. The default unit is "day". + def minimumUnusedDaysForWashOut: Int = conf + .getConfString("spark.sql.omnimv.washout.unused.day", "30") + .toInt + + // The number of materialized views to be reserved. + def reserveViewQuantityByViewCount: Int = conf + .getConfString("spark.sql.omnimv.washout.reserve.quantity.byViewCnt", "25") + .toInt + + def dropViewQuantityBySpaceConsumed: Int = conf + .getConfString("spark.sql.omnimv.washout.drop.quantity.bySpaceConsumed", "3") + .toInt + + // The default unit is "day". + def autoWashOutTimeInterval: Int = conf + .getConfString("spark.sql.omnimv.washout.automatic.time.interval", "35") + .toInt + + // Check "auto wash out" at intervals during the same session. The default unit is "second". + def autoCheckWashOutTimeInterval: Int = conf + .getConfString("spark.sql.omnimv.washout.automatic.checkTime.interval", "3600") + .toInt + + // The minimum number of views that trigger automatic wash out. 
+ def automaticWashOutMinimumViewQuantity: Int = conf + .getConfString("spark.sql.omnimv.washout.automatic.view.quantity", "20") + .toInt + + def enableAutoWashOut: Boolean = conf + .getConfString("spark.sql.omnimv.washout.automatic.enable", "false") + .toBoolean + } -object OmniCachePluginConfig { +object OmniMVPluginConfig { + // mv if enable for rewrite + val MV_REWRITE_ENABLED = "spark.omnimv.rewrite.enable" - val MV_REWRITE_ENABLED = "spark.omnicache.rewrite.enable" + // mv if enable for rewrite when update + val MV_UPDATE_REWRITE_ENABLED = "spark.omnimv.update.rewrite.enable" - val MV_UPDATE_REWRITE_ENABLED = "spark.omnicache.update.rewrite.enable" + // mv query original sql + val MV_QUERY_ORIGINAL_SQL = "spark.omnimv.query.sql.original" - val MV_QUERY_ORIGINAL_SQL = "spark.omnicache.query.sql.original" + // mv query original sql exec db + val MV_QUERY_ORIGINAL_SQL_CUR_DB = "spark.omnimv.query.sql.cur.db" - val MV_QUERY_ORIGINAL_SQL_CUR_DB = "spark.omnicache.query.sql.cur.db" + // mv latest update time + val MV_LATEST_UPDATE_TIME = "spark.omnimv.latest.update.time" - val MV_LATEST_UPDATE_TIME = "spark.omnicache.latest.update.time" + // spark job descriptor + val SPARK_JOB_DESCRIPTION = "spark.job.description" - var ins: Option[OmniCachePluginConfig] = None + var ins: Option[OmniMVPluginConfig] = None - def getConf: OmniCachePluginConfig = synchronized { + def getConf: OmniMVPluginConfig = synchronized { if (ins.isEmpty) { ins = Some(getSessionConf) } ins.get } - def getSessionConf: OmniCachePluginConfig = { - new OmniCachePluginConfig(SQLConf.get) + def getSessionConf: OmniMVPluginConfig = { + new OmniMVPluginConfig(SQLConf.get) } + /** + * + * check if table is mv + * + * @param catalogTable catalogTable + * @return true:is mv; false:is not mv + */ def isMV(catalogTable: CatalogTable): Boolean = { catalogTable.properties.contains(MV_QUERY_ORIGINAL_SQL) } + /** + * check if mv is in update + * + * @param spark spark + * @param quotedMvName quotedMvName + * @return true:is in update; false:is not in update + */ def isMVInUpdate(spark: SparkSession, quotedMvName: String): Boolean = { val names = quotedMvName.replaceAll("`", "") .split("\\.").toSeq @@ -95,6 +171,12 @@ object OmniCachePluginConfig { !catalogTable.properties.getOrElse(MV_UPDATE_REWRITE_ENABLED, "true").toBoolean } + /** + * check if mv is in update + * + * @param viewTablePlan viewTablePlan + * @return true:is in update; false:is not in update + */ def isMVInUpdate(viewTablePlan: LogicalPlan): Boolean = { val logicalRelation = viewTablePlan.asInstanceOf[LogicalRelation] !logicalRelation.catalogTable.get diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/exception/OmniMVException.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/exception/OmniMVException.scala new file mode 100644 index 0000000000000000000000000000000000000000..2985d3f2367c2a2d2c28560fbbd307f7dfecbcfb --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/exception/OmniMVException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package com.huawei.boostkit.spark.exception + +case class OmniMVException(exInfo: String) extends RuntimeException diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala similarity index 94% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala index 1174930a7de529368eecf6a225a0b69eb37a78c8..86a9f6017e77bd0e8378d30a9c601047f5823712 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprOptUtil.scala @@ -91,11 +91,12 @@ object ExprOptUtil { if (expr == null) { return } - if (expr.isInstanceOf[And]) { - decomposeConjunctions(expr.asInstanceOf[And].left, buf) - decomposeConjunctions(expr.asInstanceOf[And].right, buf) - } else { - buf.+=(expr) + expr match { + case and: And => + decomposeConjunctions(and.left, buf) + decomposeConjunctions(and.right, buf) + case _ => + buf.+=(expr) } } @@ -130,7 +131,7 @@ object ExprOptUtil { makeOr(terms, nullOnEmpty) } - def makeOr(terms: Seq[Expression], nullOnEmpty: Boolean): Expression = { + private def makeOr(terms: Seq[Expression], nullOnEmpty: Boolean): Expression = { if (terms.isEmpty) { if (nullOnEmpty) null else Literal(false, BooleanType) } else if (terms.size == 1) { @@ -149,7 +150,7 @@ object ExprOptUtil { } } - def makeAnd(terms: Seq[Expression], nullOnEmpty: Boolean): Expression = { + private def makeAnd(terms: Seq[Expression], nullOnEmpty: Boolean): Expression = { if (terms.isEmpty) { if (nullOnEmpty) null else Literal(true, BooleanType) } else if (terms.size == 1) { @@ -192,28 +193,28 @@ object ExprOptUtil { } } - def isLiteralFalse(e: Expression): Boolean = { + private def isLiteralFalse(e: Expression): Boolean = { e.isInstanceOf[Literal] && e.sql.equals("false") } - def isLiteralTrue(e: Expression): Boolean = { + private def isLiteralTrue(e: Expression): Boolean = { e.isInstanceOf[Literal] && e.sql.equals("true") } /** - * @return Whether the expression.sql in {@code srcTerms} - * contains all expression.sql in {@code dstTerms} + * @return Whether the expression.sql in srcTerms + * contains all expression.sql in stTerms */ def containsAllSql(srcTerms: Set[Expression], dstTerms: Set[Expression]): Boolean = { if (dstTerms.isEmpty || srcTerms.isEmpty) { return false } var sql: mutable.Buffer[String] = mutable.Buffer() - for (srcTerm <- srcTerms) { + for (srcTerm <- srcTerms.map(RewriteHelper.canonicalize)) { sql.+=(srcTerm.sql) } val sqlSet = sql.toSet - for (dstTerm <- dstTerms) { + for (dstTerm <- dstTerms.map(RewriteHelper.canonicalize)) { if (!sqlSet.contains(dstTerm.sql)) { return false } @@ -222,8 +223,8 @@ object ExprOptUtil { } /** - * @return Whether the expression.sql in {@code srcTerms} - * contains at least one 
expression.sql in {@code dstTerms} + * @return Whether the expression.sql in srcTerms + * contains at least one expression.sql in dstTerms */ def containsSql(srcTerms: Set[Expression], dstTerms: Set[Expression]): Boolean = { if (dstTerms.isEmpty || srcTerms.isEmpty) { @@ -293,10 +294,10 @@ object ExprOptUtil { /** Returns the kind that you get if you apply NOT to this kind. * - *
For example, {@code IS_NOT_NULL.negate()} returns {@link #IS_NULL}. + *
For example, {@code IS_NOT_NULL.negate()} returns {@link # IS_NULL}. * - *
For {@link #IS_TRUE}, {@link #IS_FALSE}, {@link #IS_NOT_TRUE}, - * {@link #IS_NOT_FALSE}, nullable inputs need to be treated carefully. + *
For {@link # IS_TRUE}, {@link # IS_FALSE}, {@link # IS_NOT_TRUE}, + * {@link # IS_NOT_FALSE}, nullable inputs need to be treated carefully. * *
{@code NOT(IS_TRUE(null))} = {@code NOT(false)} = {@code true}, * while {@code IS_FALSE(null)} = {@code false}, @@ -611,6 +612,8 @@ case class EquivalenceClasses() { } cacheEquivalenceClasses } + + override def toString: String = nodeToEquivalenceClass.toString() } object EquivalenceClasses { diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala similarity index 91% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala index 5cb7d19258cfc56f95ea0f7bebfe4847e566ef40..7f28ba1920e536f06caa66d4371da1077325f408 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ExprSimplifier.scala @@ -27,7 +27,9 @@ import scala.util.control.Breaks import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteTime import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{BooleanType, DataType, NullType} case class ExprSimplifier(unknownAsFalse: Boolean, @@ -472,6 +474,14 @@ case class ExprSimplifier(unknownAsFalse: Boolean, for (orOp <- orsOperands) { breaks3.breakable { val ors = decomposeDisjunctions(orOp).toSet + val others = terms.filter(!_.eq(orOp)).toSet + for (or <- ors) { + if (containsAllSql(others, conjunctions(or).toSet)) { + terms.-=(orOp) + breaks3.break() + } + } + for (term <- terms) { // Excluding self-simplification if (!term.eq(orOp)) { @@ -664,49 +674,51 @@ case class ExprSimplifier(unknownAsFalse: Boolean, object ExprSimplifier extends PredicateHelper { // Spark native simplification rules to be executed before this simplification - val frontRules = Seq(SimplifyCasts, ConstantFolding, UnwrapCastInBinaryComparison, ColumnPruning) + val frontRules: Seq[Rule[LogicalPlan]] = Seq() // simplify condition with pulledUpPredicates. 
def simplify(logicalPlan: LogicalPlan): LogicalPlan = { - val originPredicates: mutable.ArrayBuffer[Expression] = ArrayBuffer() - val normalizeLogicalPlan = RewriteHelper.normalizePlan(logicalPlan) - normalizeLogicalPlan foreach { - case Filter(condition, _) => - originPredicates ++= splitConjunctivePredicates(condition) - case Join(_, _, _, condition, _) if condition.isDefined => - originPredicates ++= splitConjunctivePredicates(condition.get) - case _ => - } - val inferredPlan = InferFiltersFromConstraints.apply(normalizeLogicalPlan) - val inferredPredicates: mutable.ArrayBuffer[Expression] = mutable.ArrayBuffer() - inferredPlan foreach { - case Filter(condition, _) => - inferredPredicates ++= splitConjunctivePredicates(condition) - case Join(_, _, _, condition, _) if condition.isDefined => - inferredPredicates ++= splitConjunctivePredicates(condition.get) - case _ => - } - val pulledUpPredicates: Set[Expression] = inferredPredicates.toSet -- originPredicates.toSet - // front Spark native optimize - var optPlan: LogicalPlan = normalizeLogicalPlan - for (rule <- frontRules) { - optPlan = rule.apply(optPlan) - } - optPlan transform { - case Filter(condition: Expression, child: LogicalPlan) => - val simplifyExpr = ExprSimplifier(true, pulledUpPredicates).simplify(condition) - Filter(simplifyExpr, child) - case Join(left, right, joinType, condition, hint) if condition.isDefined => - val simplifyExpr = ExprSimplifier(true, pulledUpPredicates).simplify(condition.get) - Join(left, right, joinType, Some(simplifyExpr), hint) - case other@_ => - other + RewriteTime.withTimeStat("ExprSimplifier.simplify") { + val originPredicates: mutable.ArrayBuffer[Expression] = ArrayBuffer() + val normalizeLogicalPlan = logicalPlan + normalizeLogicalPlan foreach { + case Filter(condition, _) => + originPredicates ++= splitConjunctivePredicates(condition) + case Join(_, _, _, condition, _) if condition.isDefined => + originPredicates ++= splitConjunctivePredicates(condition.get) + case _ => + } + val inferredPlan = InferFiltersFromConstraints.apply(normalizeLogicalPlan) + val inferredPredicates: mutable.ArrayBuffer[Expression] = mutable.ArrayBuffer() + inferredPlan foreach { + case Filter(condition, _) => + inferredPredicates ++= splitConjunctivePredicates(condition) + case Join(_, _, _, condition, _) if condition.isDefined => + inferredPredicates ++= splitConjunctivePredicates(condition.get) + case _ => + } + val pulledUpPredicates: Set[Expression] = inferredPredicates.toSet -- originPredicates.toSet + // front Spark native optimize + var optPlan: LogicalPlan = normalizeLogicalPlan + for (rule <- frontRules) { + optPlan = rule.apply(optPlan) + } + optPlan transform { + case Filter(condition: Expression, child: LogicalPlan) => + val simplifyExpr = ExprSimplifier(true, pulledUpPredicates).simplify(condition) + Filter(simplifyExpr, child) + case Join(left, right, joinType, condition, hint) if condition.isDefined => + val simplifyExpr = ExprSimplifier(true, pulledUpPredicates).simplify(condition.get) + Join(left, right, joinType, Some(simplifyExpr), hint) + case other@_ => + other + } } } // simplify condition without pulledUpPredicates. 
def simplify(expr: Expression): Expression = { - val fakePlan = simplify(Filter(expr, OneRowRelation())) - fakePlan.asInstanceOf[Filter].condition + val fakePlan = simplify(Filter(RewriteHelper.canonicalize(expr), OneRowRelation())) + RewriteHelper.canonicalize(fakePlan.asInstanceOf[Filter].condition) } } diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/KerberosUtil.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/KerberosUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..6f60d5a23029aa683b0278b59c1c549c9f46d549 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/KerberosUtil.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.util + +import java.io.File + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.sql.SparkSession + +object KerberosUtil { + + /** + * new configuration from spark + */ + def newConfiguration(spark: SparkSession): Configuration = { + val configuration: Configuration = spark.sessionState.newHadoopConf() + newConfiguration(configuration) + } + + /** + * new configuration from configuration + */ + def newConfiguration(configuration: Configuration): Configuration = { + val xmls = Seq("hdfs-site.xml", "core-site.xml") + val xmlDir = System.getProperty("omnimv.hdfs_conf", ".") + xmls.foreach { xml => + val file = new File(xmlDir, xml) + if (file.exists()) { + configuration.addResource(new Path(file.getAbsolutePath)) + } + } + + // security mode + if ("kerberos".equalsIgnoreCase(configuration.get("hadoop.security.authentication"))) { + val krb5Conf = System.getProperty("omnimv.krb5_conf", "/etc/krb5.conf") + System.setProperty("java.security.krb5.conf", krb5Conf) + val principal = System.getProperty("omnimv.principal") + val keytab = System.getProperty("omnimv.keytab") + if (principal == null || keytab == null) { + throw new RuntimeException("omnimv.principal or omnimv.keytab cannot be null") + } + System.setProperty("java.security.krb5.conf", krb5Conf) + UserGroupInformation.setConfiguration(configuration) + UserGroupInformation.loginUserFromKeytab(principal, keytab) + } + configuration + } +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala similarity index 51% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala rename to 
omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala index 2ef67e5361bb8f6920bce59c8519e53eedfd1c1b..b221c465ad105e33cc28a734e8bc6b968cdaafad 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteHelper.scala @@ -18,18 +18,25 @@ package com.huawei.boostkit.spark.util import com.google.common.collect.{ArrayListMultimap, BiMap, HashBiMap, Multimap} -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteTime +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.SQLConf + trait RewriteHelper extends PredicateHelper with RewriteLogger { + type ViewMetadataPackageType = (String, LogicalPlan, LogicalPlan) + val SESSION_CATALOG_NAME: String = "spark_catalog" val EMPTY_BIMAP: HashBiMap[String, String] = HashBiMap.create[String, String]() @@ -37,21 +44,27 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { mutable.Set[ExpressionEqual]] = Map[ExpressionEqual, mutable.Set[ExpressionEqual]]() val EMPTY_MULTIMAP: Multimap[Int, Int] = ArrayListMultimap.create[Int, Int]() - def mergeConjunctiveExpressions(e: Seq[Expression]): Expression = { - if (e.isEmpty) { + /** + * merge expressions by and + */ + def mergeConjunctiveExpressions(exprs: Seq[Expression]): Expression = { + if (exprs.isEmpty) { return Literal.TrueLiteral } - if (e.size == 1) { - return e.head + if (exprs.size == 1) { + return exprs.head } - e.reduce { (a, b) => + exprs.reduce { (a, b) => And(a, b) } } - def fillQualifier(logicalPlan: LogicalPlan, + /** + * fill attr's qualifier + */ + def fillQualifier(plan: LogicalPlan, exprIdToQualifier: mutable.HashMap[ExprId, AttributeReference]): LogicalPlan = { - val newLogicalPlan = logicalPlan.transform { + val newLogicalPlan = plan.transform { case plan => plan.transformExpressions { case a: AttributeReference => @@ -66,10 +79,13 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { newLogicalPlan } + /** + * fill viewTablePlan's attr's qualifier by viewQueryPlan + */ def mapTablePlanAttrToQuery(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan): LogicalPlan = { // map by index - val topProjectList: Seq[NamedExpression] = viewQueryPlan match { + var topProjectList: Seq[NamedExpression] = viewQueryPlan match { case Project(projectList, _) => projectList case Aggregate(_, aggregateExpressions, _) => @@ -93,8 +109,12 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { fillQualifier(viewTablePlan, exprIdToQualifier) } - def extractTopProjectList(logicalPlan: LogicalPlan): Seq[Expression] = { - val topProjectList: Seq[Expression] = logicalPlan match { + + /** + * extract logicalPlan output expressions + */ + def extractTopProjectList(plan: LogicalPlan): Seq[Expression] = { + val topProjectList: 
Seq[Expression] = plan match { case Project(projectList, _) => projectList case Aggregate(_, aggregateExpressions, _) => aggregateExpressions case e => extractTables(Project(e.output, e))._1.output @@ -102,24 +122,120 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { topProjectList } - def extractPredictExpressions(logicalPlan: LogicalPlan, - tableMappings: BiMap[String, String]) - : (EquivalenceClasses, Seq[ExpressionEqual], Seq[ExpressionEqual]) = { + /** + * generate (alias_exprId,alias_child_expression) + */ + def generateOrigins(plan: LogicalPlan): Map[ExprId, Expression] = { + var origins = Map.empty[ExprId, Expression] + plan.transformAllExpressions { + case a@Alias(child, _) => + origins += (a.exprId -> child) + a + case e => e + } + origins + } + + /** + * find aliased_attr's original expression + */ + def findOriginExpression(plan: LogicalPlan): LogicalPlan = { + val origins = generateOrigins(plan) + findOriginExpression(origins, plan) + } + + /** + * find aliased_attr's original expression + */ + def findOriginExpression(origins: Map[ExprId, Expression], plan: LogicalPlan): LogicalPlan = { + plan.transformAllExpressions { + case a: Alias => + a.copy(child = findOriginExpression(origins, a.child))(exprId = ExprId(0), + qualifier = a.qualifier, + explicitMetadata = a.explicitMetadata, + nonInheritableMetadataKeys = a.nonInheritableMetadataKeys) + case expr => + findOriginExpression(origins, expr) + } + } + + /** + * find aliased_attr's original expression + */ + def findOriginExpression(origins: Map[ExprId, Expression], + expression: Expression): Expression = { + def dfs(expr: Expression): Expression = { + expr.transform { + case attr: AttributeReference => + if (origins.contains(attr.exprId)) { + origins(attr.exprId) + } else { + attr + } + case e => e + } + } + + dfs(expression) + } + + /** + * flag for which condition to extract + */ + val FILTER_CONDITION: Int = 1 + val INNER_JOIN_CONDITION: Int = 1 << 1 + val OUTER_JOIN_CONDITION: Int = 1 << 2 + val COMPENSABLE_CONDITION: Int = FILTER_CONDITION | INNER_JOIN_CONDITION + val ALL_JOIN_CONDITION: Int = INNER_JOIN_CONDITION | OUTER_JOIN_CONDITION + val ALL_CONDITION: Int = INNER_JOIN_CONDITION | OUTER_JOIN_CONDITION | FILTER_CONDITION + + /** + * extract condition from (join and filter), + * then transform attr's qualifier by tableMappings + */ + def extractPredictExpressions( + plan: LogicalPlan, + tableMappings: BiMap[String, String]): ( + EquivalenceClasses, Seq[ExpressionEqual], Seq[ExpressionEqual]) = { + extractPredictExpressions(plan, tableMappings, COMPENSABLE_CONDITION) + } + + /** + * extract condition from plan by flag, + * then transform attr's qualifier by tableMappings + */ + def extractPredictExpressions(plan: LogicalPlan, + tableMappings: BiMap[String, String], conditionFlag: Int): ( + EquivalenceClasses, Seq[ExpressionEqual], Seq[ExpressionEqual]) = { var conjunctivePredicates: Seq[Expression] = Seq() var equiColumnsPreds: mutable.Buffer[Expression] = ArrayBuffer() val rangePreds: mutable.Buffer[ExpressionEqual] = ArrayBuffer() val residualPreds: mutable.Buffer[ExpressionEqual] = ArrayBuffer() - val normalizedPlan = ExprSimplifier.simplify(logicalPlan) + val normalizedPlan = plan normalizedPlan foreach { case Filter(condition, _) => - conjunctivePredicates ++= splitConjunctivePredicates(condition) - case Join(_, _, _, condition, _) => - if (condition.isDefined) { - conjunctivePredicates ++= splitConjunctivePredicates(condition.get) + if ((conditionFlag & FILTER_CONDITION) > 0) { + 
conjunctivePredicates ++= splitConjunctivePredicates(condition) + } + case Join(_, _, joinType, condition, _) => + joinType match { + case Cross => + case Inner => + if (condition.isDefined & ((conditionFlag & INNER_JOIN_CONDITION) > 0)) { + conjunctivePredicates ++= splitConjunctivePredicates(condition.get) + } + case LeftOuter | RightOuter | FullOuter | LeftSemi | LeftAnti => + if (condition.isDefined & ((conditionFlag & OUTER_JOIN_CONDITION) > 0)) { + conjunctivePredicates ++= splitConjunctivePredicates(condition.get) + } + case _ => } case _ => } - for (e <- conjunctivePredicates) { + + val origins = generateOrigins(plan) + for (src <- conjunctivePredicates) { + val e = findOriginExpression(origins, src) if (e.isInstanceOf[EqualTo]) { val left = e.asInstanceOf[EqualTo].left val right = e.asInstanceOf[EqualTo].right @@ -141,12 +257,17 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { if ((ExprOptUtil.isReference(left, allowCast = false) && ExprOptUtil.isConstant(right)) || (ExprOptUtil.isReference(right, allowCast = false) - && ExprOptUtil.isConstant(left))) { + && ExprOptUtil.isConstant(left)) + || (left.isInstanceOf[CaseWhen] + && ExprOptUtil.isConstant(right)) + || (right.isInstanceOf[CaseWhen] + && ExprOptUtil.isConstant(left)) + ) { rangePreds += ExpressionEqual(e) } else { residualPreds += ExpressionEqual(e) } - } else if (e.isInstanceOf[Or]) { + } else if (e.isInstanceOf[Or] || e.isInstanceOf[IsNull] || e.isInstanceOf[In]) { rangePreds += ExpressionEqual(e) } else { residualPreds += ExpressionEqual(e) @@ -162,22 +283,29 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { (equivalenceClasses, rangePreds, residualPreds) } - def extractTables(logicalPlan: LogicalPlan): (LogicalPlan, Set[TableEqual]) = { + /** + * extract used tables from logicalPlan + * and fill attr's qualifier + * + * @return (used tables,filled qualifier plan) + */ + def extractTables(plan: LogicalPlan): (LogicalPlan, Set[TableEqual]) = { // tableName->duplicateIndex,start from 0 val qualifierToIdx = mutable.HashMap.empty[String, Int] // logicalPlan->(tableName,duplicateIndex) - val tablePlanToIdx = mutable.HashMap.empty[LogicalPlan, (String, Int, String)] + val tablePlanToIdx = mutable.HashMap.empty[LogicalPlan, (String, Int, String, Long)] // exprId->AttributeReference,use this to replace LogicalPlan's attr val exprIdToAttr = mutable.HashMap.empty[ExprId, AttributeReference] val addIdxAndAttrInfo = (catalogTable: CatalogTable, logicalPlan: LogicalPlan, - attrs: Seq[AttributeReference]) => { + attrs: Seq[AttributeReference], seq: Long) => { val table = catalogTable.identifier.toString() val idx = qualifierToIdx.getOrElse(table, -1) + 1 qualifierToIdx += (table -> idx) tablePlanToIdx += (logicalPlan -> (table, idx, Seq(SESSION_CATALOG_NAME, catalogTable.database, - catalogTable.identifier.table, String.valueOf(idx)).mkString("."))) + catalogTable.identifier.table, String.valueOf(idx)).mkString("."), + seq)) attrs.foreach { attr => val newAttr = attr.copy()(exprId = attr.exprId, qualifier = Seq(SESSION_CATALOG_NAME, catalogTable.database, @@ -186,25 +314,80 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { } } - logicalPlan.foreachUp { + var seq = 0L + plan.foreachUp { case h@HiveTableRelation(tableMeta, _, _, _, _) => - addIdxAndAttrInfo(tableMeta, h, h.output) + seq += 1 + addIdxAndAttrInfo(tableMeta, h, h.output, seq) case h@LogicalRelation(_, _, catalogTable, _) => + seq += 1 if (catalogTable.isDefined) { - addIdxAndAttrInfo(catalogTable.get, h, h.output) + 
addIdxAndAttrInfo(catalogTable.get, h, h.output, seq) } case _ => } + plan.transformAllExpressions { + case a@Alias(child, name) => + child match { + case attr: AttributeReference => + if (exprIdToAttr.contains(attr.exprId)) { + val d = exprIdToAttr(attr.exprId) + exprIdToAttr += (a.exprId -> d + .copy(name = name)(exprId = a.exprId, qualifier = d.qualifier)) + } + case _ => + } + a + case e => e + } + val mappedTables = tablePlanToIdx.keySet.map { tablePlan => - val (tableName, idx, qualifier) = tablePlanToIdx(tablePlan) + val (tableName, idx, qualifier, seq) = tablePlanToIdx(tablePlan) TableEqual(tableName, "%s.%d".format(tableName, idx), - qualifier, fillQualifier(tablePlan, exprIdToAttr)) + qualifier, fillQualifier(tablePlan, exprIdToAttr), seq) }.toSet - val mappedQuery = fillQualifier(logicalPlan, exprIdToAttr) + val mappedQuery = fillQualifier(plan, exprIdToAttr) (mappedQuery, mappedTables) } + /** + * extract used CatalogTables from logicalPlan + * + * @return used CatalogTables + */ + def extractCatalogTablesOnly(plan: LogicalPlan): Set[CatalogTable] = { + var tables = mutable.Seq[CatalogTable]() + plan.foreachUp { + case HiveTableRelation(tableMeta, _, _, _, _) => + tables +:= tableMeta + case LogicalRelation(_, _, catalogTable, _) => + if (catalogTable.isDefined) { + tables +:= catalogTable.get + } + case p => + p.transformAllExpressions { + case e: SubqueryExpression => + tables ++= extractCatalogTablesOnly(e.plan) + e + case e => e + } + } + tables.toSet + } + + /** + * extract used tables from logicalPlan + * + * @return used tables + */ + def extractTablesOnly(plan: LogicalPlan): Set[String] = { + extractCatalogTablesOnly(plan).map(_.identifier.toString()) + } + + /** + * transform plan's attr by tableMapping then columnMapping + */ def swapTableColumnReferences[T <: Iterable[Expression]](expressions: T, tableMapping: BiMap[String, String], columnMapping: Map[ExpressionEqual, @@ -244,6 +427,9 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { result } + /** + * transform plan's attr by columnMapping then tableMapping + */ def swapColumnTableReferences[T <: Iterable[Expression]](expressions: T, tableMapping: BiMap[String, String], columnMapping: Map[ExpressionEqual, @@ -253,19 +439,150 @@ trait RewriteHelper extends PredicateHelper with RewriteLogger { result } + /** + * transform plan's attr by tableMapping + */ def swapTableReferences[T <: Iterable[Expression]](expressions: T, tableMapping: BiMap[String, String]): T = { swapTableColumnReferences(expressions, tableMapping, EMPTY_MAP) } + /** + * transform plan's attr by columnMapping + */ def swapColumnReferences[T <: Iterable[Expression]](expressions: T, columnMapping: Map[ExpressionEqual, mutable.Set[ExpressionEqual]]): T = { swapTableColumnReferences(expressions, EMPTY_BIMAP, columnMapping) } + + /** + * generate string for simplifiedPlan + * + * @param plan plan + * @param jt joinType + * @return string for simplifiedPlan + */ + def simplifiedPlanString(plan: LogicalPlan, jt: Int): String = { + val EMPTY_STRING = "" + RewriteHelper.canonicalize(ExprSimplifier.simplify(plan)).collect { + case Join(_, _, joinType, condition, hint) => + joinType match { + case Inner => + if ((INNER_JOIN_CONDITION & jt) > 0) { + joinType.toString + condition.getOrElse(Literal.TrueLiteral).sql + hint.toString() + } else { + EMPTY_STRING + } + case LeftOuter | RightOuter | FullOuter | LeftSemi | LeftAnti => + if ((OUTER_JOIN_CONDITION & jt) > 0) { + joinType.toString + condition.getOrElse(Literal.TrueLiteral).sql + 
hint.toString() + } else { + EMPTY_STRING + } + case _ => + EMPTY_STRING + } + case Filter(condition: Expression, _) => + if ((FILTER_CONDITION & jt) > 0) { + condition.sql + } else { + EMPTY_STRING + } + case HiveTableRelation(tableMeta, _, _, _, _) => + tableMeta.identifier.toString() + case LogicalRelation(_, _, catalogTable, _) => + if (catalogTable.isDefined) { + catalogTable.get.identifier.toString() + } else { + EMPTY_STRING + } + case _ => + EMPTY_STRING + }.mkString(EMPTY_STRING) + } + + /** + * check attr in viewTableAttrs + * + * @param expression expression + * @param viewTableAttrs viewTableAttrs + * @return true:in ;false:not in + */ + def isValidExpression(expression: Expression, viewTableAttrs: Set[Attribute]): Boolean = { + expression.foreach { + case attr: AttributeReference => + if (!viewTableAttrs.contains(attr)) { + return false + } + case _ => + } + true + } + + /** + * partitioned mv columns differ to mv query projectList, sort mv query projectList + */ + def sortProjectListForPartition(plan: LogicalPlan, catalogTable: CatalogTable): LogicalPlan = { + if (catalogTable.partitionColumnNames.isEmpty) { + return plan + } + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + plan match { + case Project(projectList, child) => + var newProjectList = projectList.filter(x => !partitionColumnNames.contains(x.name)) + val projectMap = projectList.map(x => (x.name, x)).toMap + newProjectList = newProjectList ++ partitionColumnNames.map(x => projectMap(x)) + Project(newProjectList, child) + case Aggregate(groupingExpressions, aggregateExpressions, child) => + var newProjectList = aggregateExpressions + .filter(x => !partitionColumnNames.contains(x.name)) + val projectMap = aggregateExpressions.map(x => (x.name, x)).toMap + newProjectList = newProjectList ++ partitionColumnNames.map(x => projectMap(x)) + Aggregate(groupingExpressions, newProjectList, child) + case p => p + } + } + + /** + * use all tables to fetch views(may match) from ViewMetaData + * + * @param tableNames tableNames in query sql + * @return Seq[(viewName, viewTablePlan, viewQueryPlan)] + */ + def getApplicableMaterializations(tableNames: Set[String]): Seq[ViewMetadataPackageType] = { + // viewName, viewTablePlan, viewQueryPlan + var viewPlans = Seq.empty[(String, LogicalPlan, LogicalPlan)] + + ViewMetadata.viewToContainsTables.forEach { (viewName, tableEquals) => + // 1.add plan info + if (tableEquals.map(_.tableName).subsetOf(tableNames)) { + val viewQueryPlan = ViewMetadata.viewToViewQueryPlan.get(viewName) + val viewTablePlan = ViewMetadata.viewToTablePlan.get(viewName) + viewPlans +:= (viewName, viewTablePlan, viewQueryPlan) + } + } + resortMaterializations(viewPlans) + } + + /** + * resort materializations by priority + */ + def resortMaterializations(candidateViewPlans: Seq[(String, + LogicalPlan, LogicalPlan)]): Seq[(String, LogicalPlan, LogicalPlan)] = { + val tuples = candidateViewPlans.sortWith((c1, c2) => + ViewMetadata.viewPriority.getOrDefault(c1._1, 0) > + ViewMetadata.viewPriority.getOrDefault(c2._1, 0) + ) + tuples + } } object RewriteHelper extends PredicateHelper with RewriteLogger { + + private val secondsInAYear = 31536000L + private val daysInTenYear = 3650 + /** * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. 
The following cases will be * equivalent: @@ -328,8 +645,28 @@ object RewriteHelper extends PredicateHelper with RewriteLogger { } def canonicalize(expression: Expression): Expression = { - val canonicalizedChildren = expression.children.map(RewriteHelper.canonicalize) - expressionReorder(expression.withNewChildren(canonicalizedChildren)) + RewriteTime.withTimeStat("canonicalize") { + val canonicalizedChildren = expression.children.map(RewriteHelper.canonicalize) + expressionReorder(expression.withNewChildren(canonicalizedChildren)) + } + } + + def canonicalize(plan: LogicalPlan): LogicalPlan = { + RewriteTime.withTimeStat("canonicalize") { + plan transform { + case filter@Filter(condition: Expression, child: LogicalPlan) => + filter.copy(canonicalize(condition), child) + case join@Join(left: LogicalPlan, right: LogicalPlan, joinType: JoinType, + condition: Option[Expression], hint: JoinHint) => + if (condition.isDefined) { + join.copy(left, right, joinType, Option(canonicalize(condition.get)), hint) + } else { + join + } + case e => + e + } + } } /** Collects adjacent commutative operations. */ @@ -396,6 +733,9 @@ object RewriteHelper extends PredicateHelper with RewriteLogger { case _ => e } + /** + * extract all attrs used in expressions + */ def extractAllAttrsFromExpression(expressions: Seq[Expression]): Set[AttributeReference] = { var attrs = Set[AttributeReference]() expressions.foreach { e => @@ -408,15 +748,18 @@ object RewriteHelper extends PredicateHelper with RewriteLogger { attrs } - def containsMV(logicalPlan: LogicalPlan): Boolean = { - logicalPlan.foreachUp { + /** + * check if logicalPlan use mv + */ + def containsMV(plan: LogicalPlan): Boolean = { + plan.foreachUp { case _@HiveTableRelation(tableMeta, _, _, _, _) => - if (OmniCachePluginConfig.isMV(tableMeta)) { + if (OmniMVPluginConfig.isMV(tableMeta)) { return true } case _@LogicalRelation(_, _, catalogTable, _) => if (catalogTable.isDefined) { - if (OmniCachePluginConfig.isMV(catalogTable.get)) { + if (OmniMVPluginConfig.isMV(catalogTable.get)) { return true } } @@ -426,16 +769,31 @@ object RewriteHelper extends PredicateHelper with RewriteLogger { } def enableCachePlugin(): Unit = { - SQLConf.get.setConfString("spark.sql.omnicache.enable", "true") + SQLConf.get.setConfString("spark.sql.omnimv.enable", "true") + SQLConf.get.setConfString("spark.sql.omnimv.log.enable", "true") } def disableCachePlugin(): Unit = { - SQLConf.get.setConfString("spark.sql.omnicache.enable", "false") + SQLConf.get.setConfString("spark.sql.omnimv.enable", "false") + SQLConf.get.setConfString("spark.sql.omnimv.log.enable", "false") } - def checkAttrsValid(logicalPlan: LogicalPlan): Boolean = { - logicalPlan.foreachUp { + def enableSqlLog(): Unit = { + SQLConf.get.setConfString("spark.sql.omnimv.log.enable", "true") + } + + def disableSqlLog(): Unit = { + SQLConf.get.setConfString("spark.sql.omnimv.log.enable", "false") + } + + /** + * check if plan's input attrs satisfy used attrs + */ + def checkAttrsValid(plan: LogicalPlan): Boolean = { + logDetail(s"checkAttrsValid for plan:$plan") + plan.foreachUp { case _: LeafNode => + case _: Expand => case plan => val attributeSets = plan.expressions.map { expression => AttributeSet.fromAttributeSets( @@ -462,6 +820,53 @@ object RewriteHelper extends PredicateHelper with RewriteLogger { } true } + + /** + * use rules to optimize queryPlan and viewQueryPlan + */ + def optimizePlan(plan: LogicalPlan): LogicalPlan = { + val rules: Seq[Rule[LogicalPlan]] = Seq( + SimplifyCasts, ConstantFolding, 
UnwrapCastInBinaryComparison, ColumnPruning) + var res = plan + RewriteTime.withTimeStat("optimizePlan") { + rules.foreach { rule => + res = rule.apply(res) + } + } + res + } + + def getMVDatabase(MVTablePlan: LogicalPlan): Option[String] = { + MVTablePlan.foreach { + case _@HiveTableRelation(tableMeta, _, _, _, _) => + return Some(tableMeta.database) + case _@LogicalRelation(_, _, catalogTable, _) => + if (catalogTable.isDefined) { + return Some(catalogTable.get.database) + } + case _: LocalRelation => + case _ => + } + None + } + + def daysToMillisecond(days: Long): Long = { + if (days > daysInTenYear || days < 0) { + throw new IllegalArgumentException( + "The day time cannot be less than 0 days" + + " or exceed 3650 days.") + } + days * 24 * 60 * 60 * 1000 + } + + def secondsToMillisecond(seconds: Long): Long = { + if (seconds > secondsInAYear || seconds < 0L) { + throw new IllegalArgumentException( + "The second time cannot be less than 0 seconds" + + " or exceed 31536000 seconds.") + } + seconds * 1000 + } } case class ExpressionEqual(expression: Expression) { @@ -476,14 +881,19 @@ case class ExpressionEqual(expression: Expression) { override def hashCode(): Int = sql.hashCode() - def extractRealExpr(expression: Expression): Expression = expression match { - case Alias(child, _) => extractRealExpr(child) - case other => other + def extractRealExpr(expression: Expression): Expression = { + expression.transform { + case Alias(child, _) => child + case Cast(child, _, _) => child + case other => other + } } + + override def toString: String = s"ExpressionEqual($sql)" } case class TableEqual(tableName: String, tableNameWithIdx: String, - qualifier: String, logicalPlan: LogicalPlan) { + qualifier: String, logicalPlan: LogicalPlan, seq: Long) { override def equals(obj: Any): Boolean = obj match { case other: TableEqual => tableNameWithIdx == other.tableNameWithIdx @@ -492,3 +902,29 @@ case class TableEqual(tableName: String, tableNameWithIdx: String, override def hashCode(): Int = tableNameWithIdx.hashCode() } + +case class AttributeReferenceEqual(attr: AttributeReference) { + override def toString: String = attr.sql + + override def equals(obj: Any): Boolean = obj match { + case attrEqual: AttributeReferenceEqual => + attr.name == attrEqual.attr.name && attr.dataType == attrEqual.attr.dataType && + attr.nullable == attrEqual.attr.nullable && attr.metadata == attrEqual.attr.metadata && + attr.qualifier == attrEqual.attr.qualifier + // case attribute: AttributeReference => + // attr.name == attribute.name && attr.dataType == attribute.dataType && + // attr.nullable == attribute.nullable && attr.metadata == attribute.metadata && + // attr.qualifier == attribute.qualifier + case _ => false + } + + override def hashCode(): Int = { + var h = 17 + h = h * 37 + attr.name.hashCode() + h = h * 37 + attr.dataType.hashCode() + h = h * 37 + attr.nullable.hashCode() + h = h * 37 + attr.metadata.hashCode() + h = h * 37 + attr.qualifier.hashCode() + h + } +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala similarity index 85% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala index 
6cd88c7d8d28bf09f6944deaa5aee5817efcbd67..545aae2fa60811600f4fee3ab4eb3b8e26263012 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/RewriteLogger.scala @@ -17,15 +17,15 @@ package com.huawei.boostkit.spark.util -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig import org.apache.spark.internal.Logging trait RewriteLogger extends Logging { - private def logLevel: String = OmniCachePluginConfig.getConf.logLevel + private def logLevel: String = OmniMVPluginConfig.getConf.logLevel - private val logFlag = "[OmniCache]" + private val logFlag = "[OmniMV]" def logBasedOnLevel(f: => String): Unit = { logLevel match { @@ -38,6 +38,14 @@ trait RewriteLogger extends Logging { } } + def logDetail(f: => String): Unit = { + logLevel match { + case "ERROR" => + logWarning(f) + case _ => + } + } + override def logInfo(msg: => String): Unit = { super.logInfo(s"$logFlag $msg") } diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala new file mode 100644 index 0000000000000000000000000000000000000000..99d2c1afa63753995a3293d45a445d5e2e2b9fe8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/ViewMetadata.scala @@ -0,0 +1,956 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.util + +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig._ +import com.huawei.boostkit.spark.util.serde.KryoSerDeUtil +import java.io.IOException +import java.net.URI +import java.util.Locale +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} +import java.util.concurrent.atomic.AtomicLong +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.{FileStatus, FileSystem, LocalFileSystem, Path} +import org.json4s.DefaultFormats +import org.json4s.jackson.Json +import scala.collection.{mutable, JavaConverters} + +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, NamedExpression} +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteTime +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.LogicalRelation + + +object ViewMetadata extends RewriteHelper { + + val viewToViewQueryPlan = new ConcurrentHashMap[String, LogicalPlan]() + + val viewToTablePlan = new ConcurrentHashMap[String, LogicalPlan]() + + val viewToContainsTables = new ConcurrentHashMap[String, Set[TableEqual]]() + + val tableToViews = new ConcurrentHashMap[String, mutable.Set[String]]() + + val viewProperties = new ConcurrentHashMap[String, Map[String, String]]() + + val viewPriority = new ConcurrentHashMap[String, Long]() + + // Map (viewName <- Array(viewCounts, lastUsedMillisecond, fileModifyTime)) + val viewCnt = new ConcurrentHashMap[String, Array[Long]]() + + var spark: SparkSession = _ + + var fs: FileSystem = _ + + var metadataPath: Path = _ + var metadataPriorityPath: Path = _ + + var initQueryPlan: Option[LogicalPlan] = None + + var washOutTimestamp: Option[Long] = Option.empty + var autoWashOutTimestamp: Option[Long] = Option.empty + + val STATUS_UN_LOAD = "UN_LOAD" + val STATUS_LOADING = "LOADING" + val STATUS_LOADED = "LOADED" + + var status: String = STATUS_UN_LOAD + + val VIEW_CNT_FILE = "viewCount" + val VIEW_CNT_FILE_LOCK = "viewCount.lock" + val DEFAULT_DATABASE = "default" + val VIEW_CONTAINS_TABLES_FILE = "viewContainsTables" + val WASH_OUT_TIMESTAMP = "washOutTimestamp" + + private var kryoSerializer: KryoSerializer = _ + + private val SEPARATOR: Char = 0xA + + val UNLOAD: Int = -1 + + private var REFRESH_STAT: String = _ + private val BUSY = "BUSY" + private val IDLE = "IDLE" + + /** + * set sparkSession + */ + def setSpark(sparkSession: SparkSession): Unit = { + spark = sparkSession + REFRESH_STAT = IDLE + status = STATUS_LOADING + + kryoSerializer = new KryoSerializer(spark.sparkContext.getConf) + + metadataPath = new Path(OmniMVPluginConfig.getConf.metadataPath) + metadataPriorityPath = new Path(metadataPath, "priority") + + val conf = KerberosUtil.newConfiguration(spark) + fs = metadataPath.getFileSystem(conf) + + val paths = Seq(metadataPath, metadataPriorityPath) + paths.foreach { path => + if (!fs.exists(path)) { + fs.mkdirs(path) + } + } + } + + /** + * save mv metadata to cache + */ + def saveViewMetadataToMap(catalogTable: CatalogTable): Unit = this.synchronized { + val viewQuerySql = catalogTable.properties.getOrElse(MV_QUERY_ORIGINAL_SQL, "") + if (viewQuerySql.isEmpty) { + logError(s"mvTable: 
${catalogTable.identifier.quotedString}'s viewQuerySql is empty!") + return + } + + // preserve preDatabase and set curDatabase + val preDatabase = spark.catalog.currentDatabase + val curDatabase = catalogTable.properties.getOrElse(MV_QUERY_ORIGINAL_SQL_CUR_DB, "") + if (curDatabase.isEmpty) { + logError(s"mvTable: ${catalogTable.identifier.quotedString}'s curDatabase is empty!") + return + } + try { + spark.sessionState.catalogManager.setCurrentNamespace(Array(curDatabase)) + + // db.table + val tableName = catalogTable.identifier.quotedString + val viewTablePlan = RewriteTime + .withTimeStat("viewTablePlan") { + spark.table(tableName).queryExecution.analyzed match { + case SubqueryAlias(_, child) => child + case a@_ => a + } + } + var viewQueryPlan = RewriteTime + .withTimeStat("viewQueryPlan") { + RewriteHelper.optimizePlan( + spark.sql(viewQuerySql).queryExecution.analyzed) + } + viewQueryPlan = viewQueryPlan match { + case RepartitionByExpression(_, child, _) => + child + case _ => + viewQueryPlan + } + // reset preDatabase + spark.sessionState.catalogManager.setCurrentNamespace(Array(preDatabase)) + + // spark_catalog.db.table + val viewName = formatViewName(catalogTable.identifier) + + // mappedViewQueryPlan and mappedViewContainsTables + val (mappedViewQueryPlan, mappedViewContainsTables) = RewriteTime + .withTimeStat("extractTables") { + extractTables(sortProjectListForPartition(viewQueryPlan, catalogTable)) + } + + mappedViewContainsTables + .foreach { mappedViewContainsTable => + val name = mappedViewContainsTable.tableName + val views = tableToViews.getOrDefault(name, mutable.Set.empty) + views += viewName + tableToViews.put(name, views) + } + + // extract view query project's Attr and replace view table's Attr by query project's Attr + // match function is attributeReferenceEqualSimple, by name and data type + // Attr of table cannot used, because same Attr in view query and view table, + // it's table is different. 
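The comment above matches attributes by name and data type rather than by exprId, because the view query plan and the view table plan carry different exprIds for the same columns. A minimal sketch of that matching idea, assuming the real work is done by mapTablePlanAttrToQuery; the helper name remapByNameAndType below is illustrative only:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Sketch only: replace each table-side attribute with the query-side
    // attribute that has the same name and data type, if one exists.
    def remapByNameAndType(tablePlan: LogicalPlan, queryPlan: LogicalPlan): LogicalPlan = {
      val queryAttrs = queryPlan.output.collect {
        case a: AttributeReference => (a.name, a.dataType) -> a
      }.toMap
      tablePlan.transformAllExpressions {
        case a: AttributeReference => queryAttrs.getOrElse((a.name, a.dataType), a)
      }
    }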
+ val mappedViewTablePlan = RewriteTime + .withTimeStat("mapTablePlanAttrToQuery") { + mapTablePlanAttrToQuery(viewTablePlan, mappedViewQueryPlan) + } + + viewToContainsTables.put(viewName, mappedViewContainsTables) + viewToViewQueryPlan.putIfAbsent(viewName, mappedViewQueryPlan) + viewToTablePlan.putIfAbsent(viewName, mappedViewTablePlan) + viewProperties.put(viewName, catalogTable.properties) + saveViewMetadataToFile(catalogTable.database, viewName) + } catch { + case e: Throwable => + logDebug(s"Failed to saveViewMetadataToMap,errmsg: ${e.getMessage}") + throw new IOException(s"Failed to save ViewMetadata to file.") + } finally { + // reset preDatabase + spark.sessionState.catalogManager.setCurrentNamespace(Array(preDatabase)) + } + } + + /** + * is metadata empty + */ + def isEmpty: Boolean = { + viewToTablePlan.isEmpty + } + + /** + * is mv exists + */ + def isViewExists(viewIdentifier: String): Boolean = { + viewToTablePlan.containsKey(viewIdentifier) + } + + /** + * add catalog table to cache + */ + def addCatalogTableToCache(table: CatalogTable): Unit = this.synchronized { + saveViewMetadataToMap(table) + if (!isViewEnable(table.properties)) { + removeMVCache(table.identifier) + } + } + + /** + * remove mv metadata from cache + */ + def removeMVCache(tableName: TableIdentifier): Unit = this.synchronized { + val viewName = formatViewName(tableName) + viewToContainsTables.remove(viewName) + viewToViewQueryPlan.remove(viewName) + viewToTablePlan.remove(viewName) + viewProperties.remove(viewName) + tableToViews.forEach { (key, value) => + if (value.contains(viewName)) { + value -= viewName + tableToViews.put(key, value) + } + } + } + + /** + * init mv metadata + */ + def init(sparkSession: SparkSession): Unit = { + init(sparkSession, None) + } + + /** + * init mv metadata with certain queryPlan + */ + def init(sparkSession: SparkSession, queryPlan: Option[LogicalPlan]): Unit = { + if (status == STATUS_LOADED) { + return + } + + initQueryPlan = queryPlan + setSpark(sparkSession) + forceLoad() + status = STATUS_LOADED + } + + // Called when ViewMetadata is initialized. 
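A minimal usage sketch of the initialization entry points above, assuming spark and queryPlan are already in scope:

    // init is a no-op once status is STATUS_LOADED.
    ViewMetadata.init(spark)
    // With enableMetadataInitByQuery, only metadata relevant to this query plan is loaded.
    ViewMetadata.init(spark, Some(queryPlan))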
+ def forceLoad(): Unit = this.synchronized { + loadViewContainsTablesFromFile() + loadViewMetadataFromFile() + loadViewPriorityFromFile() + loadViewCount() + checkViewMetadataComplete() + } + + /** + * load mv metadata from metastore + */ + def forceLoadFromMetastore(): Unit = this.synchronized { + val catalog = spark.sessionState.catalog + + // load from all db + val dbs = RewriteTime.withTimeStat("loadDbs") { + if (getConf.omniMVDB.nonEmpty) { + getConf.omniMVDB.split(",").toSeq + } else { + catalog.listDatabases() + } + } + for (db <- dbs) { + val tables = RewriteTime.withTimeStat(s"loadTable from $db") { + omniMVFilter(catalog, db) + } + RewriteTime.withTimeStat("saveViewMetadataToMap") { + tables.foreach(tableData => saveViewMetadataToMap(tableData)) + } + } + logDetail(s"tableToViews:$tableToViews") + } + + /** + * filter mv metadata from database + */ + def omniMVFilter(catalog: SessionCatalog, + mvDataBase: String): Seq[CatalogTable] = { + var res: Seq[CatalogTable] = Seq.empty[CatalogTable] + try { + val allTables = catalog.listTables(mvDataBase) + res = catalog.getTablesByName(allTables).filter { tableData => + tableData.properties.contains(MV_QUERY_ORIGINAL_SQL) + } + } catch { + // if db exists a table hive materialized view, will throw analysis exception + case e: Throwable => + logDebug(s"Failed to listTables in $mvDataBase, errmsg: ${e.getMessage}") + throw new UnsupportedOperationException("hive materialized view is not supported.") + } + res + } + + /** + * offset expression's exprId + * origin exprId + NamedExpression.newExprId.id + */ + def offsetExprId(plan: LogicalPlan): LogicalPlan = { + val offset = NamedExpression.newExprId.id + var maxId = offset + val res = plan.transformAllExpressions { + case alias: Alias => + val id = offset + alias.exprId.id + maxId = Math.max(maxId, id) + alias.copy()(exprId = alias.exprId.copy(id = id), qualifier = alias.qualifier, + explicitMetadata = alias.explicitMetadata, + nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys) + case attr: AttributeReference => + val id = offset + attr.exprId.id + maxId = Math.max(maxId, id) + attr.copy()(exprId = attr.exprId.copy(id = id), qualifier = attr.qualifier) + case e => e + } + val idField = NamedExpression.getClass.getDeclaredField("curId") + idField.setAccessible(true) + val id = idField.get(NamedExpression).asInstanceOf[AtomicLong] + id.set(maxId) + while (NamedExpression.newExprId.id <= maxId) {} + res + } + + /** + * reassign exprId from 0 before save to file + */ + def reassignExprId(plan: LogicalPlan): LogicalPlan = { + val idMappings = mutable.HashMap[Long, Long]() + var start = 0 + + def mappingId(exprId: ExprId): Long = { + val id = if (idMappings.contains(exprId.id)) { + idMappings(exprId.id) + } else { + start += 1 + idMappings += (exprId.id -> start) + start + } + id + } + + plan.transformAllExpressions { + case alias: Alias => + val id = mappingId(alias.exprId) + alias.copy()(exprId = alias.exprId.copy(id = id), qualifier = alias.qualifier, + explicitMetadata = alias.explicitMetadata, + nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys) + case attr: AttributeReference => + val id = mappingId(attr.exprId) + attr.copy()(exprId = attr.exprId.copy(id = id), qualifier = attr.qualifier) + case e => e + } + } + + /** + * save mv metadata to file + */ + def saveViewMetadataToFile(kryoSerializer: KryoSerializer, dbName: String, + viewName: String): Unit = { + val tablePlan = reassignExprId(viewToTablePlan.get(viewName)) + val queryPlan = 
reassignExprId(viewToViewQueryPlan.get(viewName)) + val properties = viewProperties.get(viewName) + + val jsons = mutable.Map[String, String]() + + val tablePlanStr = KryoSerDeUtil.serializePlan(kryoSerializer, tablePlan) + jsons += ("tablePlan" -> tablePlanStr) + + val queryPlanStr = KryoSerDeUtil.serializePlan(kryoSerializer, queryPlan) + jsons += ("queryPlan" -> queryPlanStr) + + val propertiesStr = KryoSerDeUtil.serializeToStr(kryoSerializer, properties) + jsons += ("properties" -> propertiesStr) + + jsons += (MV_REWRITE_ENABLED -> properties(MV_REWRITE_ENABLED)) + + saveMapToDisk(dbName, viewName, jsons, isAppend = false, lineFeed = false) + } + + /** + * save mv metadata to file + */ + def saveViewMetadataToFile(dbName: String, viewName: String): Unit = { + saveViewMetadataToFile(kryoSerializer, dbName, viewName) + saveViewContainsTablesToFile(dbName, viewName) + } + + /** + * save view contains tables to file + */ + def saveViewContainsTablesToFile(dbName: String, viewName: String): Unit = { + val data = loadViewContainsTablesFromFile(dbName) + data.put(viewName, (viewToContainsTables.get(viewName).map(_.tableName), + System.currentTimeMillis())) + saveMapToDisk(dbName, VIEW_CONTAINS_TABLES_FILE, data, isAppend = true, lineFeed = true) + } + + /** + * load view contains tables to file + */ + def loadViewContainsTablesFromFile(): mutable.Map[String, (Set[String], Long)] = { + val dbs = getDBs + + val jsons = mutable.Map[String, (Set[String], Long)]().empty + dbs.foreach { db => + val properties = loadViewContainsTablesFromFile(db) + for ((view, (tables, time)) <- properties) { + if (!jsons.contains(view) || jsons(view)._2 < time) { + jsons += (view -> (tables, time)) + } + } + } + jsons + } + + /** + * load view contains tables to file + */ + def loadViewContainsTablesFromFile(dbName: String): mutable.Map[String, (Set[String], Long)] = { + val jsons = mutable.Map[String, (Set[String], Long)]().empty + loadDataFromDisk(dbName, VIEW_CONTAINS_TABLES_FILE, isTailLines = true, jsons) { + (preData, curData, modifyTime) => + for ((view, (tables, time)) <- curData) { + if (!preData.contains(view) || preData(view)._2 < time) { + preData += (view -> (tables, time)) + } + } + } + } + + /** + * load view priority from file + */ + def loadViewPriorityFromFile(): Unit = { + fs.listStatus(metadataPriorityPath) + .sortWith((f1, f2) => f1.getModificationTime < f2.getModificationTime) + .foreach { file => + val is = fs.open(file.getPath) + val lines = JavaConverters + .asScalaIteratorConverter( + IOUtils.readLines(is, "UTF-8").iterator()).asScala.toSeq + is.close() + lines.foreach { line => + val views = line.split(",") + var len = views.length + views.foreach { view => + viewPriority.put(view, len) + len -= 1 + } + } + } + } + + /** + * load metadata file when mv's db=omniMVDB and mv exists + * and when enableMetadataInitByQuery only load relate with query + */ + def filterValidMetadata(): Array[FileStatus] = { + val files = fs.listStatus(metadataPath).flatMap(x => fs.listStatus(x.getPath)) + val dbs = getDBs + val dbTables = mutable.Set.empty[String] + dbs.foreach { db => + if (spark.sessionState.catalog.databaseExists(db)) { + dbTables ++= spark.sessionState.catalog.listTables(db).map(formatViewName) + } + } + var res = files.filter { file => + dbTables.contains(file.getPath.getName) + } + + if (OmniMVPluginConfig.getConf.enableMetadataInitByQuery && initQueryPlan.isDefined) { + RewriteTime.withTimeStat("loadViewContainsTablesFromFile") { + val queryTables = extractTablesOnly(initQueryPlan.get) + 
val viewContainsTables = loadViewContainsTablesFromFile() + res = res.filter { file => + val view = file.getPath.getName + viewContainsTables.contains(view) && viewContainsTables(view)._1.subsetOf(queryTables) + } + } + } + + res + } + + + /** + * load mv metadata from file + */ + def loadViewMetadataFromFile(): Unit = { + if (!fs.exists(metadataPath)) { + return + } + + val files = RewriteTime.withTimeStat("listStatus") { + filterValidMetadata() + } + + val threadPool = RewriteTime.withTimeStat("threadPool") { + Executors.newFixedThreadPool(Math.max(50, files.length * 2)) + } + + files.foreach { file => + threadPool.submit { + new Runnable { + override def run(): Unit = { + val viewName = file.getPath.getName + val is = fs.open(file.getPath) + val jsons: Map[String, String] = RewriteTime.withTimeStat("Json.read.C") { + Json(DefaultFormats).read[Map[String, String]](is) + } + is.close() + + if (!isViewEnable(jsons)) { + return + } + + val tablePlanStr = jsons("tablePlan") + val tablePlan = RewriteTime.withTimeStat("deSerTablePlan.C") { + KryoSerDeUtil.deserializePlan(kryoSerializer, spark, tablePlanStr) + } + viewToTablePlan.put(viewName, tablePlan) + + val propertiesStr = jsons("properties") + val properties = RewriteTime.withTimeStat("deSerProperties.C") { + KryoSerDeUtil.deserializeFromStr[Map[String, String]](kryoSerializer, propertiesStr) + } + viewProperties.put(viewName, properties) + } + } + } + + threadPool.submit { + new Runnable { + override def run(): Unit = { + val viewName = file.getPath.getName + val is = fs.open(file.getPath) + val jsons: Map[String, String] = RewriteTime.withTimeStat("Json.read.C") { + Json(DefaultFormats).read[Map[String, String]](is) + } + is.close() + + if (!isViewEnable(jsons)) { + return + } + + val queryPlanStr = jsons("queryPlan") + val queryPlan = RewriteTime.withTimeStat("deSerQueryPlan.C") { + KryoSerDeUtil.deserializePlan(kryoSerializer, spark, queryPlanStr) + } + viewToViewQueryPlan.put(viewName, queryPlan) + } + } + } + } + + threadPool.shutdown() + threadPool.awaitTermination(20, TimeUnit.SECONDS) + + viewProperties.keySet().forEach { viewName => + val tablePlan = viewToTablePlan.get(viewName) + val queryPlan = viewToViewQueryPlan.get(viewName) + + val resignTablePlan = RewriteTime.withTimeStat("reSignExprId") { + offsetExprId(tablePlan) + } + viewToTablePlan.put(viewName, resignTablePlan) + + val resignQueryPlan = RewriteTime.withTimeStat("reSignExprId") { + offsetExprId(queryPlan) + } + viewToViewQueryPlan.put(viewName, resignQueryPlan) + + val (_, tables) = RewriteTime.withTimeStat("extractTables") { + extractTables(resignQueryPlan) + } + viewToContainsTables.put(viewName, tables) + + RewriteTime.withTimeStat("tableToViews") { + tables.foreach { table => + val name = table.tableName + val views = tableToViews.getOrDefault(name, mutable.Set.empty) + views += viewName + tableToViews.put(name, views) + } + } + } + } + + /** + * delete mv metadata from file + */ + def deleteViewMetadata(identifier: TableIdentifier): Unit = { + removeMVCache(identifier) + val viewName = formatViewName(identifier) + fs.delete(new Path(new Path(metadataPath, identifier.database.get), viewName), true) + } + + /** + * formatted mv name + */ + def formatViewName(identifier: TableIdentifier): String = { + identifier.toString().replace("`", "").toLowerCase(Locale.ROOT) + } + + /** + * is mv enable rewrite + */ + def isViewEnable(jsons: Map[String, String]): Boolean = { + jsons.contains(MV_REWRITE_ENABLED) && jsons(MV_REWRITE_ENABLED).toBoolean + } + + /** + * check mv 
metadata load complete + */ + def checkViewMetadataComplete(): Unit = { + val loadSize = viewToViewQueryPlan.size() + var checkRes = true + checkRes &&= (loadSize == viewToTablePlan.size()) + checkRes &&= (loadSize == viewToContainsTables.size()) + checkRes &&= (loadSize == viewProperties.size()) + if (!checkRes) { + viewToViewQueryPlan.clear() + viewToTablePlan.clear() + viewToContainsTables.clear() + viewProperties.clear() + tableToViews.clear() + viewProperties.clear() + } + } + + // Called when apply a MV rewrite. + def saveViewCountToFile(): Unit = { + val dbs = mutable.Set[String]() + ViewMetadata.viewCnt.forEach { + (name, _) => + dbs.add(name.split("\\.")(0)) + } + for (db <- dbs) { + saveViewCountToFile(db) + } + } + + // Called when creating a new MV. + def saveViewCountToFile(dbName: String): Unit = { + val data: mutable.Map[String, Array[Long]] = mutable.Map[String, Array[Long]]() + ViewMetadata.viewCnt.forEach { + (name, info) => + val db = name.split("\\.")(0) + if (db.equals(dbName)) { + data.put(name, info) + } + } + saveMapToDisk(dbName, VIEW_CNT_FILE, data, isAppend = false, lineFeed = false) + } + + def loadViewCount(): Unit = { + val dbs = getDBs + dbs.foreach { + db => + loadViewCount(db) + } + } + + def loadViewCount(dbName: String): Unit = { + // clear viewCnt info in dbName + val iterator = viewCnt.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + if (entry.getKey.split("\\.")(0) equals dbName) iterator.remove + } + + val viewCounts = mutable.Map[String, Array[Long]]().empty + viewCounts ++= loadDataFromDisk(dbName, VIEW_CNT_FILE, isTailLines = true, viewCounts) { + (preData, newData, modifyTime) => + for (data <- newData) { + val dataWithModifyTime = (data._1, data._2.slice(0, 2) ++ Array(modifyTime)) + preData += dataWithModifyTime + } + } + + // set view count into ViewMetadata.viewCnt + for (viewCount <- viewCounts) { + viewCnt.put(viewCount._1, viewCount._2) + } + } + + /** + * load data from disk. + * + * @param dbName Which directory in the metadata stores this data. + * @param fileName Which file in the metadata directory stores this data. + * @param isTailLines + * @param data Data to be stored and data is of type Map. 
+ * @tparam K is the type of key for the Map + * @tparam V V is the type of value for the Map + * @return + */ + private def loadDataFromDisk[K: Manifest, V: Manifest]( + dbName: String, + fileName: String, + isTailLines: Boolean, + data: mutable.Map[K, V]) + (addNewDataToPreData: ( + mutable.Map[K, V], + mutable.Map[K, V], + Long) => Unit): mutable.Map[K, V] = { + + val dbPath = new Path(metadataPath, dbName) + val filePath = new Path(dbPath, fileName) + loadMapFromDisk(filePath, isTailLines, data)(addNewDataToPreData) + } + + private def loadMapFromDisk[K: Manifest, V: Manifest]( + filePath: Path, + isTailLines: Boolean, + data: mutable.Map[K, V]) + (addNewDataToPreData: ( + mutable.Map[K, V], + mutable.Map[K, V], + Long) => Unit): mutable.Map[K, V] = { + val newData = data.empty + if (!fs.exists(filePath)) { + return newData + } + var readLines = OmniMVPluginConfig.getConf.metadataIndexTailLines + val is = fs.open(filePath) + var pos = fs.getFileStatus(filePath).getLen - 1 + val modifyTime = fs.getFileStatus(filePath).getModificationTime + var lineReady = false + var bytes = mutable.Seq.empty[Char] + // tail the file + while (pos >= 0) { + is.seek(pos) + val readByte = is.readByte() + readByte match { + // \n + case SEPARATOR => + if (bytes.size != 0) { + lineReady = true + } + case _ => + bytes +:= readByte.toChar + } + pos -= 1 + + // find \n or file start + if (lineReady || pos < 0) { + val line = bytes.mkString("") + val properties = Json(DefaultFormats) + .read[mutable.Map[K, V]](line) + addNewDataToPreData(newData, properties, modifyTime) + lineReady = false + bytes = mutable.Seq.empty[Char] + + if (isTailLines) { + readLines -= 1 + if (readLines <= 0) { + return newData + } + } + } + } + is.close() + newData + } + + private def loadStrFromDisk(filePath: Path): String = { + if (!fs.exists(filePath)) { + return "" + } + val in = fs.open(filePath) + val ciphertext = IOUtils.toByteArray(in).map(_.toChar).mkString("") + in.close() + ciphertext + } + + /** + * save data to disk. + * Metadata information is classified by DBNames. + * + * @param dbName Which directory in the metadata stores this data. + * @param fileName Which file in the metadata directory stores this data. + * @param data Data to be stored and data is of type Map. + * @tparam K K is the type of key for the Map + * @tparam V V is the type of value for the Map + */ + def saveMapToDisk[K: Manifest, V: Manifest]( + dbName: String, + fileName: String, + data: mutable.Map[K, V], + isAppend: Boolean, + lineFeed: Boolean): Unit = { + val dbPath = new Path(metadataPath, dbName) + val file = new Path(dbPath, fileName) + val os = if (!fs.exists(file) || !isAppend || fs.isInstanceOf[LocalFileSystem]) { + fs.create(file, true) + } else { + fs.append(file) + } + // append + val jsonFile = Json(DefaultFormats).write(data) + os.write(jsonFile.getBytes()) + // line feed + if (lineFeed) { + os.write(SEPARATOR) + } + os.close() + } + + private def saveStrToDisk( + file: Path, + data: String, + isAppend: Boolean): Unit = { + val os = if (!fs.exists(file) || !isAppend || fs.isInstanceOf[LocalFileSystem]) { + fs.create(file, true) + } else { + fs.append(file) + } + IOUtils.write(data, os) + os.close() + } + + /** + * If "spark.sql.omnimv.dbs" specifies databases, + * the databases are used. + * Otherwise, all databases in the metadata directory are obtained by default. 
+ * + * @return + */ + def getDBs: Set[String] = { + if (OmniMVPluginConfig.getConf.omniMVDB.nonEmpty) { + OmniMVPluginConfig.getConf.omniMVDB + .split(",").map(_.toLowerCase(Locale.ROOT)).toSet + } else { + fs.listStatus(metadataPath).map(_.getPath.getName).toSet + } + } + + // just for test. + def getViewCntPath: String = { + VIEW_CNT_FILE + } + + def getDefaultDatabase: String = { + DEFAULT_DATABASE + } + + def saveWashOutTimestamp(): Unit = { + val map = mutable.Map[String, Long]() + if (washOutTimestamp.isDefined) { + map += ("washOutTimestamp" -> washOutTimestamp.get) + } + val str = KryoSerDeUtil.serializeToStr(kryoSerializer, map) + saveStrToDisk(new Path(metadataPath, WASH_OUT_TIMESTAMP), str, isAppend = false) + } + + def loadWashOutTimestamp(): Unit = { + val ciphertext = loadStrFromDisk(new Path(metadataPath, WASH_OUT_TIMESTAMP)) + val timestamp = KryoSerDeUtil.deserializeFromStr[mutable.Map[String, Long]]( + kryoSerializer, ciphertext) + if (timestamp != null) { + washOutTimestamp = timestamp.get(WASH_OUT_TIMESTAMP) + } + } + + def getViewCntModifyTime(viewCnt: ConcurrentHashMap[String, Array[Long]]): Option[Long] = { + viewCnt.forEach { + (_, value) => + return Some(value(2)) + } + Option.empty + } + + def getViewDependsTableTime(viewName: String): Map[String, String] = { + var catalogTables: Set[CatalogTable] = Set() + viewToContainsTables.get(viewName).map(_.logicalPlan) + .foreach(plan => catalogTables ++= extractCatalogTablesOnly(plan)) + getViewDependsTableTime(catalogTables) + } + + def getViewDependsTableTime(catalogTables: Set[CatalogTable]): Map[String, String] = { + var viewDependsTableTime = Map[String, String]() + catalogTables.foreach { catalogTable => + viewDependsTableTime += (formatViewName(catalogTable.identifier) -> + getPathTime(catalogTable.storage.locationUri.get).toString) + } + viewDependsTableTime + } + + def getViewDependsTableTimeStr(viewQueryPlan: LogicalPlan): String = { + val str: String = Json(DefaultFormats).write( + getViewDependsTableTime(extractCatalogTablesOnly(viewQueryPlan))) + str + } + + def getLastViewDependsTableTime(viewName: String): Map[String, String] = { + Json(DefaultFormats).read[Map[String, String]]( + viewProperties.get(viewName)(MV_LATEST_UPDATE_TIME)) + } + + def getPathTime(uri: URI): Long = { + fs.getFileStatus(new Path(uri)).getModificationTime + } + + def checkViewDataReady(viewName: String): Unit = { + if (REFRESH_STAT equals BUSY) { + return + } + val lastTime = getLastViewDependsTableTime(viewName) + val nowTime = getViewDependsTableTime(viewName) + if (lastTime != nowTime) { + REFRESH_STAT = BUSY + RewriteTime.withTimeStat("REFRESH MV") { + val sqlText = spark.sparkContext.getLocalProperty(SPARK_JOB_DESCRIPTION) + RewriteHelper.enableSqlLog() + spark.sql(s"REFRESH MATERIALIZED VIEW $viewName;") + RewriteHelper.disableSqlLog() + spark.sparkContext.setJobDescription(sqlText) + val newProperty = ViewMetadata.viewProperties.get(viewName) + + (MV_LATEST_UPDATE_TIME -> Json(DefaultFormats).write(nowTime)) + ViewMetadata.viewProperties.put(viewName, newProperty) + val viewDB = viewName.split("\\.")(0) + saveViewMetadataToFile(viewDB, viewName) + } + val updateReason = nowTime.toSeq.filter { kv => + !lastTime.contains(kv._1) || lastTime(kv._1) != kv._2 + }.toString() + logBasedOnLevel(s"REFRESH MATERIALIZED VIEW $viewName; " + + s"for depends table has updated $updateReason") + REFRESH_STAT = IDLE + } + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/FileLock.scala 
b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/FileLock.scala new file mode 100644 index 0000000000000000000000000000000000000000..0bab983a1a95f48a9d205ba749d2a583777058fc --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/FileLock.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.util.lock + +import java.io.{FileNotFoundException, IOException} + +import org.apache.hadoop.fs.{FileSystem, Path} + + +case class FileLock(fs: FileSystem, lockFile: Path) { + def isLocked: Boolean = { + if (fs.exists(lockFile)) { + return true + } + false + } + + def lock: Boolean = { + var res = true + try { + val out = fs.create(lockFile, false) + out.close() + } catch { + case _ => + res = false + } + res + } + + def unLock: Boolean = { + try { + fs.delete(lockFile, true) + } catch { + case _ => + throw new IOException("[OmniMVAtomic] unlock failed.") + } + } + + /** + * Determine whether the lock times out. + * The default timeout period is 1 minute. + */ + def isTimeout: Boolean = { + val curTime = System.currentTimeMillis() + var modifyTime = curTime + try { + modifyTime = fs.getFileStatus(lockFile).getModificationTime + } catch { + case e: FileNotFoundException => + // It is not an atomic operation, so it is normal for this exception to exist. + } + val duration = curTime - modifyTime + // 60000 sec equal 1 minute + val threshold = 60000 + if (threshold < duration) { + return true + } + false + } + + /** + * When a timeout occurs, other tasks try to release the lock. + */ + def releaseLock(): Unit = { + try { + fs.delete(lockFile, true) + } catch { + case _: Throwable => + } + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/OmniMVAtomic.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/OmniMVAtomic.scala new file mode 100644 index 0000000000000000000000000000000000000000..54e643cdaa4218bfa2d34d4a79d3d86c768d6b46 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/lock/OmniMVAtomic.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.util.lock + +import com.huawei.boostkit.spark.util.RewriteLogger + + +object OmniMVAtomic extends RewriteLogger { + // func atomicity is guaranteed through file locks + private def atomicFunc(fileLock: FileLock)(func: () => Unit): Boolean = { + if (fileLock.isLocked || !fileLock.lock) { + return false + } + try { + func() + } catch { + case e: Throwable => + throw e + } finally { + fileLock.unLock + } + true + } + + private def timeoutReleaseLock(fileLock: FileLock): Unit = { + if (fileLock.isTimeout) { + logError("[Omni Atomic] lock expired.") + fileLock.releaseLock() + } + } + + // The spin waits or gets the lock to perform the operation + def funcWithSpinLock(fileLock: FileLock)(func: () => Unit): Unit = { + while (!atomicFunc(fileLock)(func)) { + logInfo("[Omni Atomic] wait for lock.") + timeoutReleaseLock(fileLock) + } + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/KryoSerDeUtil.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/KryoSerDeUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..6c554743abeadd33143284f1cce919657091c08d --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/KryoSerDeUtil.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
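An illustrative critical section built from FileLock and OmniMVAtomic above, mirroring how the view-count file is updated later in this patch; fs, dbPath and dbName are assumed to be in scope:

    import org.apache.hadoop.fs.Path

    val lock = FileLock(fs, new Path(dbPath, ViewMetadata.VIEW_CNT_FILE_LOCK))
    OmniMVAtomic.funcWithSpinLock(lock) { () =>
      // Read-modify-write of the shared view-count file happens under the lock file;
      // a stale lock older than one minute is released by waiting tasks.
      ViewMetadata.saveViewCountToFile(dbName)
    }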
+ */ + +package com.huawei.boostkit.spark.util.serde + +import com.esotericsoftware.kryo.io.Input +import java.io.ByteArrayOutputStream +import java.util.Base64 +import org.apache.hadoop.fs.Path + +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteTime +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat + +object KryoSerDeUtil { + + /** + * serialize object to byte array + * + * @param kryoSerializer kryoSerializer + * @param obj obj + * @tparam T obj type + * @return serialized byte array + */ + def serialize[T](kryoSerializer: KryoSerializer, obj: T): Array[Byte] = { + val kryo = kryoSerializer.newKryo() + val baos = new ByteArrayOutputStream() + val output = kryoSerializer.newKryoOutput() + output.setOutputStream(baos) + + kryo.writeClassAndObject(output, obj) + output.close() + baos.toByteArray + } + + /** + * serialize object to string + * + * @param kryoSerializer kryoSerializer + * @param obj obj + * @tparam T obj type + * @return serialized string + */ + def serializeToStr[T](kryoSerializer: KryoSerializer, obj: T): String = { + val byteArray = serialize[T](kryoSerializer, obj) + Base64.getEncoder.encodeToString(byteArray) + } + + /** + * deserialize byte array to object + * + * @param kryoSerializer kryoSerializer + * @param byteArray byteArray + * @tparam T obj type + * @return deserialized object + */ + def deserialize[T](kryoSerializer: KryoSerializer, byteArray: Array[Byte]): T = { + RewriteTime.withTimeStat("deserialize.C") { + val kryo = RewriteTime.withTimeStat("deserialize.newKryo.C") { + kryoSerializer.newKryo() + } + val input = new Input() + if (byteArray == null || byteArray.size == 0) { + input.setBuffer(new Array[Byte](4096)) + } else { + input.setBuffer(byteArray) + } + + val obj = RewriteTime.withTimeStat("deserialize.readClassAndObject.C") { + kryo.readClassAndObject(input) + } + obj.asInstanceOf[T] + } + } + + /** + * deserialize string to object + * + * @param kryoSerializer kryoSerializer + * @param str str + * @tparam T obj type + * @return deserialized object + */ + def deserializeFromStr[T](kryoSerializer: KryoSerializer, str: String): T = { + val byteArray = RewriteTime.withTimeStat("Base64.getDecoder.decode.C") { + Base64.getDecoder.decode(str) + } + deserialize[T](kryoSerializer, byteArray) + } + + /** + * serialize logicalPlan to string + * + * @param kryoSerializer kryoSerializer + * @param plan plan + * @return serialized string + */ + def serializePlan(kryoSerializer: KryoSerializer, plan: LogicalPlan): String = { + val wrappedPlan = wrap(plan) + serializeToStr[LogicalPlan](kryoSerializer, wrappedPlan) + } + + /** + * deserialize string to logicalPlan + * + * @param kryoSerializer kryoSerializer + * @param spark spark + * @param serializedPlan serializedPlan + * @return logicalPlan + */ + def deserializePlan( + kryoSerializer: KryoSerializer, spark: SparkSession, serializedPlan: String): LogicalPlan = { + val wrappedPlan = deserializeFromStr[LogicalPlan](kryoSerializer, serializedPlan) + unwrap(spark, wrappedPlan) + } + + /** + * wrap logicalPlan if cannot serialize + * + * @param plan plan + * @return wrapped plan + */ + def wrap(plan: 
LogicalPlan): LogicalPlan = { + // subqeury contains plan + val newPlan = plan.transformAllExpressions { + case e: ScalarSubquery => + ScalarSubqueryWrapper(wrap(e.plan), e.children, e.exprId) + case e: ListQuery => + ListQueryWrapper(wrap(e.plan), e.children, e.exprId, e.childOutputs) + case e: InSubquery => + InSubqueryWrapper( + e.values, + ListQueryWrapper( + wrap(e.query.plan), + e.query.children, + e.query.exprId, + e.query.childOutputs)) + case e: Exists => + ExistsWrapper(wrap(e.plan), e.children, e.exprId) + case e: ScalaUDF => + ScalaUDFWrapper( + e.function, + e.dataType, + e.children, + e.inputEncoders, + e.outputEncoder, + e.udfName, + e.nullable, + e.udfDeterministic) + } + newPlan.transform { + case p: With => + With(wrap(p.child), p.cteRelations.map { + case (r, s) => (r, SubqueryAlias(s.alias, wrap(s.child))) + }) + case p: Intersect => + IntersectWrapper(wrap(p.left), wrap(p.right), p.isAll) + case p: Except => + ExceptWrapper(wrap(p.left), wrap(p.right), p.isAll) + case LogicalRelation( + HadoopFsRelation( + location: FileIndex, + partitionSchema, + dataSchema, + bucketSpec, + fileFormat, + options), + output, + catalogTable, + isStreaming) => + LogicalRelationWrapper( + HadoopFsRelationWrapper( + wrapFileIndex(location), + partitionSchema, + dataSchema, + bucketSpec, + wrapFileFormat(fileFormat), + options), + output, + catalogTable, + isStreaming) + } + } + + /** + * unwrap logicalPlan to original logicalPlan + * + * @param spark spark + * @param plan plan + * @return original logicalPlan + */ + def unwrap(spark: SparkSession, plan: LogicalPlan): LogicalPlan = { + RewriteTime.withTimeStat("unwrap.C") { + val newPlan = plan.transform { + case p: With => + With(unwrap(spark, p.child), p.cteRelations.map { + case (r, s) => (r, SubqueryAlias(s.alias, unwrap(spark, s.child))) + }) + case p: IntersectWrapper => + Intersect(unwrap(spark, p.left), unwrap(spark, p.right), p.isAll) + case p: ExceptWrapper => + Except(unwrap(spark, p.right), unwrap(spark, p.right), p.isAll) + case LogicalRelationWrapper( + HadoopFsRelationWrapper( + location: FileIndex, + partitionSchema, + dataSchema, + bucketSpec, + fileFormat, + options), + output, + catalogTable, + isStreaming) => + LogicalRelation( + HadoopFsRelation( + unwrapFileIndex(spark, location), + partitionSchema, + dataSchema, + bucketSpec, + unwrapFileFormat(fileFormat), + options)(spark), + output, + catalogTable, + isStreaming) + case h: HiveTableRelation => + h.copy(prunedPartitions = None) + } + + newPlan.transformAllExpressions { + case e: ScalarSubqueryWrapper => + ScalarSubquery(unwrap(spark, e.plan), e.children, e.exprId) + case e: ListQueryWrapper => + ListQueryWrapper(unwrap(spark, e.plan), e.children, e.exprId, e.childOutputs) + case e: InSubqueryWrapper => + InSubquery( + e.values, + ListQuery( + unwrap(spark, e.query.plan), + e.query.children, + e.query.exprId, + e.query.childOutputs)) + case e: ExistsWrapper => + Exists(unwrap(spark, e.plan), e.children, e.exprId) + case e: ScalaUDFWrapper => + ScalaUDF( + e.function, + e.dataType, + e.children, + e.inputEncoders, + e.outputEncoder, + e.udfName, + e.nullable, + e.udfDeterministic + ) + } + } + } + + def wrapFileIndex(fileIndex: FileIndex): FileIndex = { + fileIndex match { + case location: InMemoryFileIndex => + InMemoryFileIndexWrapper(location.rootPaths.map(path => path.toString)) + case location: CatalogFileIndex => + CatalogFileIndexWrapper(location.table, location.sizeInBytes) + case other => + other + } + } + + def unwrapFileIndex(spark: SparkSession, fileIndex: 
FileIndex): FileIndex = { + fileIndex match { + case location: InMemoryFileIndexWrapper => + new InMemoryFileIndex( + spark, + location.rootPathsSpecified.map(path => new Path(path)), + Map(), + None) + case location: CatalogFileIndexWrapper => + new CatalogFileIndex( + spark, + location.table, + location.sizeInBytes) + case other => + other + } + } + + def wrapFileFormat(fileFormat: FileFormat): FileFormat = { + fileFormat match { + case _: CSVFileFormat => CSVFileFormatWrapper + case _: JsonFileFormat => JsonFileFormatWrapper + case other => other + } + } + + def unwrapFileFormat(fileFormat: FileFormat): FileFormat = { + fileFormat match { + case CSVFileFormatWrapper => new CSVFileFormat + case JsonFileFormatWrapper => new JsonFileFormat + case other => other + } + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/LogicalPlanWrapper.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/LogicalPlanWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..f101b57efa35c1db1c946dbec8cc0d4e4df9e047 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/com/huawei/boostkit/spark/util/serde/LogicalPlanWrapper.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
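A minimal round trip through the serde helpers above, assuming spark is an active SparkSession and analyzedPlan is some analyzed LogicalPlan:

    import org.apache.spark.serializer.KryoSerializer

    val kryo = new KryoSerializer(spark.sparkContext.getConf)
    // serializePlan wraps non-serializable nodes, Kryo-serializes, then Base64-encodes.
    val encoded: String = KryoSerDeUtil.serializePlan(kryo, analyzedPlan)
    // deserializePlan reverses the steps and unwraps back to executable plan nodes.
    val restored = KryoSerDeUtil.deserializePlan(kryo, spark, encoded)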
+ */ + +package com.huawei.boostkit.spark.util.serde + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, LeafNode, LogicalPlan} +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.{DataType, StructType} + +trait LogicalPlanWrapper + +/** + * parent class no default constructor + */ +trait NoDefaultConstructor extends LogicalPlanWrapper + +/** + * class contains variable like SparkSession,Configuration + */ +trait InMemoryStates extends LogicalPlanWrapper + +abstract class SubqueryExpressionWrapper + extends Unevaluable with NoDefaultConstructor { + override def nullable: Boolean = throw new UnsupportedOperationException() + + override def dataType: DataType = throw new UnsupportedOperationException() +} + +case class ScalarSubqueryWrapper(plan: LogicalPlan, children: Seq[Expression], exprId: ExprId) + extends SubqueryExpressionWrapper { + override def dataType: DataType = plan.schema.fields.head.dataType + + override def toString: String = s"scalar-subquery-wrapper#${exprId.id}" +} + +case class ListQueryWrapper(plan: LogicalPlan, children: Seq[Expression], exprId: ExprId, + childOutputs: Seq[Attribute]) + extends SubqueryExpressionWrapper { + override def toString(): String = s"list-wrapper#${exprId.id}" +} + +case class InSubqueryWrapper(values: Seq[Expression], query: ListQueryWrapper) + extends Predicate with Unevaluable { + override def children: Seq[Expression] = values :+ query + + override def nullable: Boolean = throw new UnsupportedOperationException() +} + +case class ExistsWrapper(plan: LogicalPlan, children: Seq[Expression], exprId: ExprId) + extends SubqueryExpressionWrapper { + override def toString(): String = s"exists-wrapper#${exprId.id}" +} + +case class ScalaUDFWrapper( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, + outputEncoder: Option[ExpressionEncoder[_]] = None, + udfName: Option[String] = None, + nullable: Boolean = true, + udfDeterministic: Boolean = true) + extends Expression with Unevaluable with NoDefaultConstructor + +case class IntersectWrapper( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) + extends BinaryNode with NoDefaultConstructor { + + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) + } +} + +case class ExceptWrapper( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean) + extends BinaryNode with NoDefaultConstructor { + + override def output: Seq[Attribute] = left.output +} + +case class InMemoryFileIndexWrapper(rootPathsSpecified: Seq[String]) + extends FileIndex with InMemoryStates { + override def rootPaths: Seq[Path] = throw new UnsupportedOperationException() + + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = + throw new UnsupportedOperationException() + + override def inputFiles: Array[String] = throw new UnsupportedOperationException() + + override def refresh(): 
Unit = throw new UnsupportedOperationException() + + override def sizeInBytes: Long = throw new UnsupportedOperationException() + + override def partitionSchema: StructType = throw new UnsupportedOperationException() +} + +case class CatalogFileIndexWrapper(table: CatalogTable, + override val sizeInBytes: Long) + extends FileIndex with InMemoryStates { + override def rootPaths: Seq[Path] = throw new UnsupportedOperationException() + + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = + throw new UnsupportedOperationException() + + override def inputFiles: Array[String] = throw new UnsupportedOperationException() + + override def refresh(): Unit = throw new UnsupportedOperationException() + + override def partitionSchema: StructType = throw new UnsupportedOperationException() +} + +case class HadoopFsRelationWrapper( + location: FileIndex, + partitionSchema: StructType, + dataSchema: StructType, + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String]) + extends BaseRelation with FileRelation with InMemoryStates { + override def sqlContext: SQLContext = throw new UnsupportedOperationException() + + override def schema: StructType = throw new UnsupportedOperationException() + + override def inputFiles: Array[String] = throw new UnsupportedOperationException() +} + +case class LogicalRelationWrapper( + relation: BaseRelation, + output: Seq[AttributeReference], + catalogTable: Option[CatalogTable], + override val isStreaming: Boolean) + extends LeafNode with InMemoryStates + +abstract class FileFormatWrapper extends FileFormat { + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = + throw new UnsupportedOperationException() + + override def prepareWrite(sparkSession: SparkSession, + job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = + throw new UnsupportedOperationException() +} + +case object CSVFileFormatWrapper extends FileFormatWrapper + +case object JsonFileFormatWrapper extends FileFormatWrapper diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniCacheToSparkAdapter.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniMVToSparkAdapter.scala similarity index 95% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniCacheToSparkAdapter.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniMVToSparkAdapter.scala index 6fff4c6a058947045ffbd8cd2995bd9c91a91802..a8eebd1799078cf60a9ffdf432ad33f245baa899 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniCacheToSparkAdapter.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OmniMVToSparkAdapter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.SparkOptimizer import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.types.StructType -object OmniCacheToSparkAdapter extends SQLConfHelper with RewriteLogger { +object OmniMVToSparkAdapter extends SQLConfHelper with RewriteLogger { def buildCatalogTable( table: TableIdentifier, @@ -79,7 +79,7 @@ object OmniCacheToSparkAdapter extends SQLConfHelper with RewriteLogger { } } -case class 
OmniCacheOptimizer(session: SparkSession, optimizer: Optimizer) extends +case class OmniMVOptimizer(session: SparkSession, optimizer: Optimizer) extends SparkOptimizer(session.sessionState.catalogManager, session.sessionState.catalog, session.sessionState.experimentalMethods) { diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala similarity index 67% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala index 49028767b1bba8d7ac17a7856738210e7882e33b..0df3561a288d9672faece7eb2dea4230865db170 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/AbstractMaterializedViewRule.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.google.common.collect._ -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig import com.huawei.boostkit.spark.util._ +import com.huawei.boostkit.spark.util.ViewMetadata._ +import com.huawei.boostkit.spark.util.lock.{FileLock, OmniMVAtomic} import org.apache.calcite.util.graph.{DefaultEdge, Graphs} +import org.apache.hadoop.fs.Path import scala.collection.{mutable, JavaConverters} import scala.util.control.Breaks import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -37,140 +40,181 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) /** * try match the queryPlan and viewPlan ,then rewrite by viewPlan * - * @param topProject queryTopProject - * @param plan queryPlan - * @param usingMvs usingMvs + * @param topProject queryTopProject + * @param plan queryPlan + * @param usingMvInfos usingMvInfos + * @param candidateViewPlan candidateViewPlan * @return performedPlan */ def perform(topProject: Option[Project], plan: LogicalPlan, - usingMvs: mutable.Set[String]): LogicalPlan = { + usingMvInfos: mutable.Set[(String, String)], + candidateViewPlan: ViewMetadataPackageType): LogicalPlan = { var finalPlan = if (topProject.isEmpty) plan else topProject.get + logDetail(s"enter rule:${this.getClass.getName} perform for plan:$finalPlan") - if (ViewMetadata.status == ViewMetadata.STATUS_LOADING) { - return finalPlan - } - RewriteTime.withTimeStat("viewMetadata") { - ViewMetadata.init(sparkSession) - } // 1.check query sql is match current rule if (ViewMetadata.isEmpty || !isValidPlan(plan)) { + if (ViewMetadata.isEmpty) { + logDetail("ViewMetadata.isEmpty") + } else { + logDetail(s"queryPlan isValidPlan") + } return finalPlan } // 2.extract tablesInfo from queryPlan and replace the AttributeReference // in plan using tableAttr val (queryExpr, queryTables) = extractTables(finalPlan) + 
logDetail(s"queryTables:$queryTables") // 3.use all tables to fetch views(may match) from ViewMetaData - val candidateViewPlans = RewriteTime.withTimeStat("getApplicableMaterializations") { - getApplicableMaterializations(queryTables.map(t => t.tableName)) - .filter(x => !OmniCachePluginConfig.isMVInUpdate(x._2)) - } - if (candidateViewPlans.isEmpty) { - return finalPlan - } // continue for curPlanLoop,mappingLoop val curPlanLoop = new Breaks val mappingLoop = new Breaks // 4.iterate views,try match and rewrite - for ((viewName, srcViewTablePlan, srcViewQueryPlan) <- candidateViewPlans) { - curPlanLoop.breakable { - // 4.1.check view query sql is match current rule - if (!isValidPlan(srcViewQueryPlan)) { - curPlanLoop.break() - } + val (viewName, srcViewTablePlan, srcViewQueryPlan) = candidateViewPlan + val viewDatabase = RewriteHelper.getMVDatabase(srcViewTablePlan) + curPlanLoop.breakable { + logDetail(s"iterate view:$viewName, viewTablePlan:$srcViewTablePlan, " + + s"viewQueryPlan:$srcViewQueryPlan") + // 4.1.check view query sql is match current rule + if (!isValidPlan(srcViewQueryPlan)) { + logDetail(s"viewPlan isValidPlan:$srcViewQueryPlan") + curPlanLoop.break() + } - OmniCachePluginConfig.getConf.setCurMatchMV(viewName) - // 4.2.view plans - var viewTablePlan = srcViewTablePlan - var viewQueryPlan = srcViewQueryPlan - var topViewProject: Option[Project] = None - var viewQueryExpr: LogicalPlan = viewQueryPlan - viewQueryPlan match { - case p: Project => - topViewProject = Some(p) - viewQueryPlan = p.child - viewQueryExpr = p - case _ => - } + OmniMVPluginConfig.getConf.setCurMatchMV(viewName) + // 4.2.view plans + var viewTablePlan = aliasViewTablePlan(srcViewTablePlan, queryExpr) + var viewQueryPlan = srcViewQueryPlan + var topViewProject: Option[Project] = None + var viewQueryExpr: LogicalPlan = viewQueryPlan + viewQueryPlan match { + case p: Project => + topViewProject = Some(p) + viewQueryPlan = p.child + viewQueryExpr = p + case _ => + } + + // 4.3.extract tablesInfo from viewPlan + val viewTables = ViewMetadata.viewToContainsTables.get(viewName) - // 4.3.extract tablesInfo from viewPlan - val viewTables = ViewMetadata.viewToContainsTables.get(viewName) + // 4.4.compute the relation of viewTableInfo and queryTableInfo + // 4.4.1.queryTableInfo containsAll viewTableInfo + if (!viewTables.subsetOf(queryTables)) { + logDetail(s"viewTables is not subsetOf queryTables") + curPlanLoop.break() + } - // 4.4.compute the relation of viewTableInfo and queryTableInfo - // 4.4.1.queryTableInfo containsAll viewTableInfo - if (!viewTables.subsetOf(queryTables)) { + // 4.4.2.queryTableInfo!=viewTableInfo, need do join compensate + val needCompensateTables = queryTables -- viewTables + logDetail(s"needCompensateTables:$needCompensateTables") + if (needCompensateTables.nonEmpty) { + val sortedNeedCompensateTables = needCompensateTables.toSeq.sortWith { + (t1: TableEqual, t2: TableEqual) => + t1.seq < t2.seq + } + logDetail(f"sortedNeedCompensateTables:$sortedNeedCompensateTables") + val newViewPlans = compensateViewPartial(viewTablePlan, + viewQueryExpr, topViewProject, sortedNeedCompensateTables) + if (newViewPlans.isEmpty) { curPlanLoop.break() } + val (newViewTablePlan, newViewQueryPlan, newTopViewProject) = newViewPlans.get + viewTablePlan = newViewTablePlan + viewQueryPlan = newViewQueryPlan + viewQueryExpr = newViewQueryPlan + topViewProject = newTopViewProject + } - // 4.4.2.queryTableInfo!=viewTableInfo, need do join compensate - val needCompensateTables = queryTables -- viewTables - if 
(needCompensateTables.nonEmpty) { - val newViewPlans = compensateViewPartial(viewTablePlan, - viewQueryPlan, topViewProject, needCompensateTables) - if (newViewPlans.isEmpty) { - curPlanLoop.break() + // 4.5.extractPredictExpressions from viewQueryPlan and mappedQueryPlan + val queryPredictExpression = RewriteTime.withTimeStat("extractPredictExpressions") { + extractPredictExpressions(queryExpr, EMPTY_BIMAP) + } + logDetail(s"queryPredictExpression:$queryPredictExpression") + + val viewProjectList = extractTopProjectList(viewQueryExpr) + val viewTableAttrs = viewTablePlan.output + + // 4.6.if a table emps used >=2 times in a sql (query and view) + // we should try the combination,switch the seq + // view:SELECT V1.locationid,V2.empname FROM emps V1 JOIN emps V2 + // ON V1.deptno='1' AND V2.deptno='2' AND V1.empname = V2.empname; + // query:SELECT V2.locationid,V1.empname FROM emps V1 JOIN emps V2 + // ON V1.deptno='2' AND V2.deptno='1' AND V1.empname = V2.empname; + val flatListMappings: Seq[BiMap[String, String]] = generateTableMappings(queryTables) + + flatListMappings.foreach { queryToViewTableMapping => + mappingLoop.breakable { + val inverseTableMapping = queryToViewTableMapping.inverse() + logDetail(s"iterate queryToViewTableMapping:$inverseTableMapping") + val viewPredictExpression = RewriteTime.withTimeStat("extractPredictExpressions") { + extractPredictExpressions(viewQueryExpr, + inverseTableMapping) + } + logDetail(s"viewPredictExpression:$viewPredictExpression") + + // 4.7.compute compensationPredicates between viewQueryPlan and queryPlan + var newViewTablePlan = RewriteTime.withTimeStat("computeCompensationPredicates") { + computeCompensationPredicates(viewTablePlan, + queryPredictExpression, viewPredictExpression, inverseTableMapping, + viewPredictExpression._1.getEquivalenceClassesMap, + viewProjectList, viewTableAttrs) + } + logDetail(s"computeCompensationPredicates plan:$newViewTablePlan") + // 4.8.compensationPredicates isEmpty, because view's row data cannot satisfy query + if (newViewTablePlan.isEmpty) { + logDetail("computeCompensationPredicates plan isEmpty") + mappingLoop.break() } - val (newViewTablePlan, newViewQueryPlan, newTopViewProject) = newViewPlans.get - viewTablePlan = newViewTablePlan - viewQueryPlan = newViewQueryPlan - viewQueryExpr = newViewQueryPlan - topViewProject = newTopViewProject - } - // 4.5.extractPredictExpressions from viewQueryPlan and mappedQueryPlan - val queryPredictExpression = RewriteTime.withTimeStat("extractPredictExpressions") { - extractPredictExpressions(queryExpr, EMPTY_BIMAP) - } + // 4.9.use viewTablePlan(join compensated), query project, + // compensationPredicts to rewrite final plan - val viewProjectList = extractTopProjectList(viewQueryExpr) - val viewTableAttrs = viewTablePlan.output - - // 4.6.if a table emps used >=2 times in a sql (query and view) - // we should try the combination,switch the seq - // view:SELECT V1.locationid,V2.empname FROM emps V1 JOIN emps V2 - // ON V1.deptno='1' AND V2.deptno='2' AND V1.empname = V2.empname; - // query:SELECT V2.locationid,V1.empname FROM emps V1 JOIN emps V2 - // ON V1.deptno='2' AND V2.deptno='1' AND V1.empname = V2.empname; - val flatListMappings: Seq[BiMap[String, String]] = generateTableMappings(queryTables) - - flatListMappings.foreach { queryToViewTableMapping => - mappingLoop.breakable { - val inverseTableMapping = queryToViewTableMapping.inverse() - val viewPredictExpression = RewriteTime.withTimeStat("extractPredictExpressions") { - 
extractPredictExpressions(viewQueryExpr, - inverseTableMapping) + newViewTablePlan = RewriteTime.withTimeStat("rewriteView") { + rewriteView(newViewTablePlan.get, viewQueryExpr, + queryExpr, inverseTableMapping, + queryPredictExpression._1.getEquivalenceClassesMap, + viewProjectList, viewTableAttrs) + } + logDetail(s"rewriteView plan:$newViewTablePlan") + if (newViewTablePlan.isEmpty || !RewriteHelper.checkAttrsValid(newViewTablePlan.get)) { + logDetail("rewriteView plan isEmpty") + mappingLoop.break() + } + assert(viewDatabase.isDefined) + if (RewriteHelper.containsMV(newViewTablePlan.get)) { + // atomic update ViewMetadata.viewCnt + val dbName = viewName.split("\\.")(0) + val dbPath = new Path(metadataPath, dbName) + val dbViewCnt = new Path(dbPath, VIEW_CNT_FILE) + val fileLock = FileLock(fs, new Path(dbPath, VIEW_CNT_FILE_LOCK)) + OmniMVAtomic.funcWithSpinLock(fileLock) { + () => + if (fs.exists(dbViewCnt)) { + val curModifyTime = fs.getFileStatus(dbViewCnt).getModificationTime + if (ViewMetadata.getViewCntModifyTime(viewCnt).getOrElse(0L) != curModifyTime) { + loadViewCount(dbName) + } + } + val preViewCnt = ViewMetadata.viewCnt.getOrDefault( + viewName, Array[Long](0, System.currentTimeMillis())) + ViewMetadata.viewCnt.put( + viewName, Array(preViewCnt(0) + 1, System.currentTimeMillis())) + saveViewCountToFile(dbName) + loadViewCount(dbName) } + } - // 4.7.compute compensationPredicates between viewQueryPlan and queryPlan - var newViewTablePlan = RewriteTime.withTimeStat("computeCompensationPredicates") { - computeCompensationPredicates(viewTablePlan, - queryPredictExpression, viewPredictExpression, inverseTableMapping, - viewPredictExpression._1.getEquivalenceClassesMap, - viewProjectList, viewTableAttrs) - } - // 4.8.compensationPredicates isEmpty, because view's row data cannot satisfy query - if (newViewTablePlan.isEmpty) { - mappingLoop.break() - } + ViewMetadata.checkViewDataReady(viewName) - // 4.9.use viewTablePlan(join compensated), query project, - // compensationPredicts to rewrite final plan - newViewTablePlan = RewriteTime.withTimeStat("rewriteView") { - rewriteView(newViewTablePlan.get, viewQueryExpr, - queryExpr, inverseTableMapping, - queryPredictExpression._1.getEquivalenceClassesMap, - viewProjectList, viewTableAttrs) - } - if (newViewTablePlan.isEmpty) { - mappingLoop.break() - } - finalPlan = newViewTablePlan.get - usingMvs += viewName - return finalPlan - } + finalPlan = newViewTablePlan.get + finalPlan = sparkSession.sessionState.analyzer.execute(finalPlan) + usingMvInfos += viewName -> viewDatabase.get + return finalPlan } } } @@ -209,29 +253,31 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) } /** - * use all tables to fetch views(may match) from ViewMetaData + * basic check for outjoin * - * @param tableNames tableNames in query sql - * @return Seq[(viewName, viewTablePlan, viewQueryPlan)] + * @param logicalPlan LogicalPlan + * @return true:matched ; false:unMatched */ - def getApplicableMaterializations(tableNames: Set[String]): Seq[(String, - LogicalPlan, LogicalPlan)] = { - // viewName, viewTablePlan, viewQueryPlan - var viewPlans = Seq.empty[(String, LogicalPlan, LogicalPlan)] - val viewNames = mutable.Set.empty[String] - // 1.topological iterate graph - tableNames.foreach { tableName => - if (ViewMetadata.tableToViews.containsKey(tableName)) { - viewNames ++= ViewMetadata.tableToViews.get(tableName) - } - } - viewNames.foreach { viewName => - // 4.add plan info - val viewQueryPlan = 
ViewMetadata.viewToViewQueryPlan.get(viewName) - val viewTablePlan = ViewMetadata.viewToTablePlan.get(viewName) - viewPlans +:= (viewName, viewTablePlan, viewQueryPlan) + def isValidOutJoinLogicalPlan(logicalPlan: LogicalPlan): Boolean = { + logicalPlan.foreach { + case _: LogicalRelation => + case _: HiveTableRelation => + case _: Project => + case _: Filter => + case j: Join => + j.joinType match { + case _: Inner.type => + case _: LeftOuter.type => + case _: RightOuter.type => + case _: FullOuter.type => + case _: LeftSemi.type => + case _: LeftAnti.type => + case _ => return false + } + case _: SubqueryAlias => + case _ => return false } - viewPlans + true } /** @@ -264,7 +310,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) def compensateViewPartial(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan, topViewProject: Option[Project], - needTables: Set[TableEqual]): + needTables: Seq[TableEqual]): Option[(LogicalPlan, LogicalPlan, Option[Project])] = None /** @@ -323,6 +369,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) */ def generateEquivalenceClasses(queryEC: EquivalenceClasses, viewEC: EquivalenceClasses): Option[Expression] = { + logDetail(s"generateEquivalenceClasses queryEC:$queryEC, viewEC:$viewEC") // 1.all empty,valid if (queryEC.getEquivalenceClassesMap.isEmpty && viewEC.getEquivalenceClassesMap.isEmpty) { return Some(Literal.TrueLiteral) @@ -330,6 +377,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) // 2.query is empty,invalid if (queryEC.getEquivalenceClassesMap.isEmpty && viewEC.getEquivalenceClassesMap.nonEmpty) { + logDetail("queryEC.isEmpty && viewEC.nonEmpty") return None } @@ -338,7 +386,9 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) val viewEquivalenceClasses = viewEC.getEquivalenceClasses val mappingOp: Option[Multimap[Int, Int]] = extractPossibleMapping(queryEquivalenceClasses, viewEquivalenceClasses) + logDetail(s"queryEc to viewEc mappingOp:$mappingOp") if (mappingOp.isEmpty) { + logDetail("mappingOp.isEmpty") return None } val mapping = mappingOp.get @@ -423,9 +473,10 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) * @return compensate Expression */ def splitFilter(queryExpression: Expression, viewExpression: Expression): Option[Expression] = { + logDetail(s"splitFilter for queryExpression:$queryExpression, viewExpression:$viewExpression") // 1.canonicalize expression,main for reorder - val queryExpression2 = RewriteHelper.canonicalize(ExprSimplifier.simplify(queryExpression)) - val viewExpression2 = RewriteHelper.canonicalize(ExprSimplifier.simplify(viewExpression)) + val queryExpression2 = ExprSimplifier.simplify(queryExpression) + val viewExpression2 = ExprSimplifier.simplify(viewExpression) // 2.or is residual predicts,this main deal residual predicts val z = splitOr(queryExpression2, viewExpression2) @@ -440,15 +491,18 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) // 4.viewExpression2 and not(queryExpression2) val x = andNot(viewExpression2, queryExpression2) + logDetail(s"view andNot query:$x") // then check some absolutely invalid situation if (mayBeSatisfiable(x)) { // 4.1.queryExpression2 and viewExpression2 val x2 = ExprOptUtil.composeConjunctions( Seq(queryExpression2, viewExpression2), nullOnEmpty = false) + logDetail(s"query and view :$x2") // 4.2.canonicalize - val r = RewriteHelper.canonicalize(ExprSimplifier.simplify(x2)) + val r = ExprSimplifier.simplify(x2) if 
(ExprOptUtil.isAlwaysFalse(r)) { + logDetail(s"query and view isAlwaysFalse:$r") return None } @@ -459,7 +513,9 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) val residue = (conjs -- views).map(_.expression).toSeq return Some(ExprOptUtil.composeConjunctions(residue, nullOnEmpty = false)) } + logDetail(s"query != (query and view):$queryExpression2 != $r") } + logDetail(s"view andNot query not satisfy") None } @@ -599,7 +655,9 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) // 1.compute equalColumnCompensation val compensationColumnsEquiPredicts = generateEquivalenceClasses( queryPredict._1, viewPredict._1) + logDetail(s"compensationColumnsEquiPredicts:$compensationColumnsEquiPredicts") if (compensationColumnsEquiPredicts.isEmpty) { + logDetail("compensationColumnsEquiPredicts.isEmpty") return None } @@ -611,7 +669,9 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) val compensationRangePredicts = splitFilter( mergeConjunctiveExpressions(queryRangePredicts), mergeConjunctiveExpressions(viewRangePredicts)) + logDetail(s"compensationRangePredicts:$compensationRangePredicts") if (compensationRangePredicts.isEmpty) { + logDetail("compensationRangePredicts.isEmpty") return None } @@ -623,7 +683,9 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) val compensationResidualPredicts = splitFilter( mergeConjunctiveExpressions(queryResidualPredicts), mergeConjunctiveExpressions(viewResidualPredicts)) + logDetail(s"compensationResidualPredicts:$compensationResidualPredicts") if (compensationResidualPredicts.isEmpty) { + logDetail("compensationResidualPredicts.isEmpty") return None } @@ -631,6 +693,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) val columnsEquiPredictsResult = rewriteExpressions(Seq(compensationColumnsEquiPredicts.get), swapTableColumn = false, tableMapping, columnMapping, viewProjectList, viewTableAttrs) if (columnsEquiPredictsResult.isEmpty) { + logDetail("columnsEquiPredictsResult.isEmpty") return None } @@ -639,6 +702,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) compensationResidualPredicts.get), swapTableColumn = true, tableMapping, queryColumnMapping, viewProjectList, viewTableAttrs) if (otherPredictsResult.isEmpty) { + logDetail("otherPredictsResult.isEmpty") return None } @@ -679,6 +743,10 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) // 3.iterate exprsToRewrite and dfs mapping expression to ViewTableAttributeReference by map val result = exprsToRewrite.map { expr => expr.transform { + case e@Literal(_, _) => + e + case e@Alias(Literal(_, _), _) => + e case e: NamedExpression => val expressionEqual = ExpressionEqual(e) if (viewProjectExprToTableAttr.contains(expressionEqual)) { @@ -690,21 +758,34 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) case e => e } }.asInstanceOf[T] + Some(result) + } - // 4.iterate result and dfs check every AttributeReference in ViewTableAttributeReference - val viewTableAttrsSet = swapTableAttrs.map(_.exprId).toSet - result.foreach { expr => - expr.foreach { - case attr: AttributeReference => - if (!viewTableAttrsSet.contains(attr.exprId)) { - logBasedOnLevel(s"attr:%s cannot found in view:%s" - .format(attr, OmniCachePluginConfig.getConf.curMatchMV)) - return None - } - case _ => + /** + * alias ViewTablePlan's attr by queryPlan's attr + * + * @param viewTablePlan viewTablePlan + * @param queryPlan queryPlan + * @return 
aliasViewTablePlan + */ + def aliasViewTablePlan(viewTablePlan: LogicalPlan, queryPlan: LogicalPlan): LogicalPlan = { + val viewTableAttrs = viewTablePlan.output + var alias = Map[String, AttributeReference]() + queryPlan.transformAllExpressions { + case attr: AttributeReference => + alias += (attr.sql -> attr) + attr + case e => e + } + val aliasViewTableAttrs = viewTableAttrs.map { attr => + val queryAttr = alias.get(attr.sql) + if (queryAttr.isDefined) { + Alias(attr, queryAttr.get.name)(exprId = queryAttr.get.exprId) + } else { + attr } } - Some(result) + Project(aliasViewTableAttrs, viewTablePlan) } /** @@ -715,19 +796,24 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) * @param originExpressions originExpressions * @return aliasExpressions */ - def aliasExpressions(newExpressions: Seq[NamedExpression], + def aliasExpressions(newExpressions: Seq[Expression], originExpressions: Seq[NamedExpression]): Seq[NamedExpression] = { val result = newExpressions.zip(originExpressions) .map { q => val rewrited = q._1 val origin = q._2 - if (rewrited.exprId == origin.exprId) { - rewrited - } else { - Alias(rewrited, origin.name)(exprId = origin.exprId) + rewrited match { + case r: NamedExpression => + if (r.exprId == origin.exprId) { + rewrited + } else { + Alias(rewrited, origin.name)(exprId = origin.exprId) + } + case _ => + Alias(rewrited, origin.name)(exprId = origin.exprId) } } - result + result.map(_.asInstanceOf[NamedExpression]) } /** @@ -756,7 +842,7 @@ abstract class AbstractMaterializedViewRule(sparkSession: SparkSession) } val aliasedExpressions = aliasExpressions( - rewritedExpressions.get.map(_.asInstanceOf[NamedExpression]).toSeq, originExpressions) + rewritedExpressions.get.toSeq, originExpressions) Some(aliasedExpressions.asInstanceOf[T]) } diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala similarity index 33% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala index e198bfcdf3bb9601595ed20c958d5eb814828215..aab66f084051d9bf36a1aa825a1a548c23987ea0 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MVRewriteRule.scala @@ -18,86 +18,198 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.fasterxml.jackson.annotation.JsonIgnore -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig -import com.huawei.boostkit.spark.util.{RewriteHelper, RewriteLogger} +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.util.{RewriteHelper, RewriteLogger, ViewMetadata} +import com.huawei.boostkit.spark.util.ViewMetadata._ +import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftAnti, LeftOuter, LeftSemi, RightOuter} import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.command.OmniCacheCreateMvCommand +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable, OptimizedCreateHiveTableAsSelectCommand} import org.apache.spark.status.ElementTrackingStore import org.apache.spark.util.kvstore.KVIndex -class MVRewriteRule(session: SparkSession) extends Rule[LogicalPlan] with RewriteLogger { - val omniCacheConf: OmniCachePluginConfig = OmniCachePluginConfig.getConf +class MVRewriteRule(session: SparkSession) + extends Rule[LogicalPlan] with RewriteHelper with RewriteLogger { + private var cannotRewritePlans: Set[LogicalPlan] = Set[LogicalPlan]() - val joinRule = new MaterializedViewJoinRule(session) - val aggregateRule = new MaterializedViewAggregateRule(session) + private val omniMVConf: OmniMVPluginConfig = OmniMVPluginConfig.getConf + + private val joinRule = new MaterializedViewJoinRule(session) + private val outJoinRule = new MaterializedViewOutJoinRule(session) + private val aggregateRule = new MaterializedViewAggregateRule(session) + private val outJoinAggregateRule = new MaterializedViewOutJoinAggregateRule(session) override def apply(logicalPlan: LogicalPlan): LogicalPlan = { - if (!omniCacheConf.enableOmniCache) { + if (!omniMVConf.enableOmniMV || cannotRewritePlans.contains(logicalPlan)) { return logicalPlan } + RewriteHelper.disableSqlLog() try { logicalPlan match { - case _: OmniCacheCreateMvCommand => + case _: CreateHiveTableAsSelectCommand => + tryRewritePlan(logicalPlan) + case _: OptimizedCreateHiveTableAsSelectCommand => + tryRewritePlan(logicalPlan) + case _: InsertIntoHadoopFsRelationCommand => + tryRewritePlan(logicalPlan) + case _: InsertIntoHiveTable => + tryRewritePlan(logicalPlan) + case _: Command => logicalPlan case _ => tryRewritePlan(logicalPlan) } } catch { case e: Throwable => - logWarning(s"Failed to rewrite plan with mv,errmsg: ${e.getMessage}") + logError(s"Failed to rewrite plan with mv.") logicalPlan + } finally { + RewriteHelper.enableSqlLog() } } def tryRewritePlan(plan: LogicalPlan): LogicalPlan = { - val usingMvs = mutable.Set.empty[String] + val usingMvInfos = mutable.Set.empty[(String, String)] RewriteTime.clear() val rewriteStartSecond = System.currentTimeMillis() - val res = plan.transformDown { - case p: Project => - joinRule.perform(Some(p), p.child, usingMvs) - case a: Aggregate => - var rewritedPlan = aggregateRule.perform(None, a, usingMvs) - // below agg may be join/filter can be rewrite - if (rewritedPlan == a) { - val child = Project( - RewriteHelper.extractAllAttrsFromExpression(a.aggregateExpressions).toSeq, a.child) - val rewritedChild = joinRule.perform(Some(child), child.child, usingMvs) - if (rewritedChild != child) { - val projectChild = rewritedChild.asInstanceOf[Project] - rewritedPlan = a.copy(child = Project( - projectChild.projectList ++ projectChild.child.output, projectChild.child)) - } + + if (ViewMetadata.status == ViewMetadata.STATUS_LOADING) { + return plan + } + // init viewMetadata by full queryPlan + RewriteTime.withTimeStat("viewMetadata") { + ViewMetadata.init(session, Some(plan)) + } + + // automatic wash out + if (OmniMVPluginConfig.getConf.enableAutoWashOut) { + val autoCheckInterval: Long = RewriteHelper.secondsToMillisecond( + OmniMVPluginConfig.getConf.autoCheckWashOutTimeInterval) + val autoWashOutTime: Long = 
ViewMetadata.autoWashOutTimestamp.getOrElse(0) + if ((System.currentTimeMillis() - autoWashOutTime) >= autoCheckInterval) { + automaticWashOutCheck() + } + } + + var res = RewriteHelper.optimizePlan(plan) + val queryTables = extractTablesOnly(res).toSet + val candidateViewPlans = RewriteTime.withTimeStat("getApplicableMaterializations") { + getApplicableMaterializations(queryTables) + .filter(x => !OmniMVPluginConfig.isMVInUpdate(x._2)) + } + + if (candidateViewPlans.isEmpty) { + logDetail(s"no candidateViewPlans") + } else { + for (candidateViewPlan <- candidateViewPlans) { + res = res.transformDown { + case r => + if (RewriteHelper.containsMV(r)) { + r + } else { + r match { + case p: Project => + if (containsOuterJoin(p)) { + outJoinRule.perform(Some(p), p.child, usingMvInfos, candidateViewPlan) + } else { + joinRule.perform(Some(p), p.child, usingMvInfos, candidateViewPlan) + } + case a: Aggregate => + var rewritedPlan = if (containsOuterJoin(a)) { + outJoinAggregateRule.perform(None, a, usingMvInfos, candidateViewPlan) + } else { + aggregateRule.perform(None, a, usingMvInfos, candidateViewPlan) + } + // below agg may be join/filter can be rewrite + if (rewritedPlan == a && !a.child.isInstanceOf[Project]) { + val child = Project( + RewriteHelper.extractAllAttrsFromExpression( + a.aggregateExpressions).toSeq, a.child) + val rewritedChild = if (containsOuterJoin(a)) { + outJoinRule.perform(Some(child), child.child, usingMvInfos, candidateViewPlan) + } else { + joinRule.perform(Some(child), child.child, usingMvInfos, candidateViewPlan) + } + if (rewritedChild != child) { + val projectChild = rewritedChild.asInstanceOf[Project] + rewritedPlan = a.copy(child = Project( + projectChild.projectList ++ projectChild.child.output, projectChild.child)) + } + } + rewritedPlan + case p => p + } + } } - rewritedPlan - case p => p + } } - if (usingMvs.nonEmpty) { + + RewriteTime.queue.add(("load_mv.nums", ViewMetadata.viewToTablePlan.size())) + if (usingMvInfos.nonEmpty) { RewriteTime.withTimeStat("checkAttrsValid") { if (!RewriteHelper.checkAttrsValid(res)) { + RewriteTime.statFromStartTime("total", rewriteStartSecond) + logBasedOnLevel(RewriteTime.stat()) return plan } } val sql = session.sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) - val mvs = usingMvs.mkString(";").replaceAll("`", "") + val mvs = usingMvInfos.mkString(";").replaceAll("`", "") val costSecond = (System.currentTimeMillis() - rewriteStartSecond).toString val log = ("logicalPlan MVRewrite success," + - "using materialized view:[%s],cost %s milliseconds,original sql:%s") - .format(mvs, costSecond, sql) + "using materialized view:[%s],cost %s milliseconds,") + .format(mvs, costSecond) logBasedOnLevel(log) session.sparkContext.listenerBus.post(SparkListenerMVRewriteSuccess(sql, mvs)) + } else { + res = plan + cannotRewritePlans += res } RewriteTime.statFromStartTime("total", rewriteStartSecond) - logBasedOnLevel(RewriteTime.timeStat.toString()) + logBasedOnLevel(RewriteTime.stat()) res } + + def containsOuterJoin(plan: LogicalPlan): Boolean = { + plan.foreach { + case j: Join => + j.joinType match { + case LeftOuter => return true + case RightOuter => return true + case FullOuter => return true + case LeftSemi => return true + case LeftAnti => return true + case _ => + } + case _ => + } + false + } + + private def automaticWashOutCheck(): Unit = { + val timeInterval = OmniMVPluginConfig.getConf.autoWashOutTimeInterval + val threshold = System.currentTimeMillis() - RewriteHelper.daysToMillisecond(timeInterval) + val 
viewQuantity = OmniMVPluginConfig.getConf.automaticWashOutMinimumViewQuantity + + loadViewCount() + loadWashOutTimestamp() + + if (ViewMetadata.viewCnt.size() >= viewQuantity && + (ViewMetadata.washOutTimestamp.isEmpty || + (ViewMetadata.washOutTimestamp.get <= threshold))) { + ViewMetadata.spark.sql("WASH OUT MATERIALIZED VIEW") + logInfo("WASH OUT MATERIALIZED VIEW BY AUTOMATICALLY.") + ViewMetadata.autoWashOutTimestamp = Some(System.currentTimeMillis()) + } + } } @DeveloperApi @@ -121,13 +233,15 @@ class MVRewriteSuccessListener( object RewriteTime { val timeStat: mutable.Map[String, Long] = mutable.HashMap[String, Long]() + val queue = new LinkedBlockingQueue[(String, Long)]() def statFromStartTime(key: String, startTime: Long): Unit = { - timeStat += (key -> (timeStat.getOrElse(key, 0L) + System.currentTimeMillis() - startTime)) + queue.add((key, System.currentTimeMillis() - startTime)) } def clear(): Unit = { timeStat.clear() + queue.clear() } def withTimeStat[T](key: String)(f: => T): T = { @@ -138,4 +252,16 @@ object RewriteTime { statFromStartTime(key, startTime) } } + + def stat(): String = { + queue.forEach { infos => + val (key, time) = infos + if (key.endsWith(".C")) { + timeStat += (key -> Math.max(timeStat.getOrElse(key, 0L), time)) + } else { + timeStat += (key -> (timeStat.getOrElse(key, 0L) + time)) + } + } + s"plugin cost:${timeStat.toSeq.sortWith((a, b) => a._2 > b._2).toString()}" + } } diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala similarity index 59% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala index 1edce0eae1ebee0362092fd71be938cc4176b661..abc5c6fb6c6fc26ed0235052b22e629331bb4d46 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRule.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.google.common.collect.BiMap -import com.huawei.boostkit.spark.util.{ExpressionEqual, TableEqual} +import com.huawei.boostkit.spark.util.{ExpressionEqual, RewriteHelper, TableEqual} import scala.collection.mutable import org.apache.spark.sql.SparkSession @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DecimalType class MaterializedViewAggregateRule(sparkSession: SparkSession) @@ -55,7 +56,7 @@ class MaterializedViewAggregateRule(sparkSession: SparkSession) override def compensateViewPartial(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan, topViewProject: Option[Project], - needTables: Set[TableEqual]): + needTables: Seq[TableEqual]): Option[(LogicalPlan, LogicalPlan, Option[Project])] = { // newViewTablePlan var newViewTablePlan = viewTablePlan @@ -144,7 +145,7 @@ class MaterializedViewAggregateRule(sparkSession: 
SparkSession) // if subGroupExpressionEquals is empty and aggCalls all in viewAggExpressionEquals, // final need project not aggregate val isJoinCompensated = viewTablePlan.isInstanceOf[Join] - var projectFlag = subGroupExpressionEquals.isEmpty && !isJoinCompensated + val projectFlag = subGroupExpressionEquals.isEmpty && !isJoinCompensated // 3.1.viewGroupExpressionEquals is same to queryGroupExpressionEquals if (projectFlag) { @@ -169,7 +170,12 @@ class MaterializedViewAggregateRule(sparkSession: SparkSession) // such as max(c1),min(c1),sum(c1),avg(c1),count(distinct c1), // if c1 in view,it can support } else { - return None + expr match { + case Literal(_, _) | Alias(Literal(_, _), _) => + case _ => + logDetail(s"expr:$expr cannot found in viewQueryPlan") + return None + } } newQueryAggExpressions :+= expr.asInstanceOf[NamedExpression] } @@ -177,35 +183,68 @@ class MaterializedViewAggregateRule(sparkSession: SparkSession) } else { queryAggExpressionEquals.foreach { aggCall => var expr = aggCall.expression - expr match { - case Alias(AggregateExpression(_, _, isDistinct, _, _), _) => - if (isDistinct) { - return None - } - case _ => - } // rollUp and use viewTableAttr if (viewAggExpressionEquals.contains(aggCall)) { val viewTableAttr = viewTableAttrs(viewAggExpressionEqualsOrdinal(aggCall)) .asInstanceOf[AttributeReference] val qualifier = viewTableAttr.qualifier expr = expr match { - case a@Alias(agg@AggregateExpression(Sum(_), _, _, _, _), _) => - copyAlias(a, agg.copy(aggregateFunction = Sum(viewTableAttr)), qualifier) + case a@Alias(agg@AggregateExpression(Sum(_), _, isDistinct, _, _), _) => + if (isDistinct) { + return None + } + agg.resultAttribute match { + case DecimalType.Expression(prec, scale) => + copyAlias(a, MakeDecimal(agg.copy(aggregateFunction = + Sum(UnscaledValue(viewTableAttr))), prec, scale), qualifier) + case _ => + copyAlias(a, agg.copy(aggregateFunction = Sum(viewTableAttr)), qualifier) + } case a@Alias(agg@AggregateExpression(Min(_), _, _, _, _), _) => copyAlias(a, agg.copy(aggregateFunction = Min(viewTableAttr)), qualifier) case a@Alias(agg@AggregateExpression(Max(_), _, _, _, _), _) => copyAlias(a, agg.copy(aggregateFunction = Max(viewTableAttr)), qualifier) - case a@Alias(agg@AggregateExpression(Count(_), _, _, _, _), _) => + case a@Alias(agg@AggregateExpression(Count(_), _, isDistinct, _, _), _) => + if (isDistinct) { + return None + } copyAlias(a, agg.copy(aggregateFunction = Sum(viewTableAttr)), qualifier) case a@Alias(AttributeReference(_, _, _, _), _) => copyAlias(a, viewTableAttr, viewTableAttr.qualifier) + case a@Alias(agg@AggregateExpression(Average(child), _, isDistinct, _, _), _) => + if (isDistinct) { + return None + } + val count = ExpressionEqual(agg.copy(aggregateFunction = Count(child))) + if (viewAggExpressionEquals.contains(count)) { + val countAttr = viewTableAttrs(viewAggExpressionEqualsOrdinal(count)) + .asInstanceOf[AttributeReference] + copyAlias(a, Divide( + agg.copy(aggregateFunction = Sum(Multiply(viewTableAttr, countAttr)), + resultId = NamedExpression.newExprId), + agg.copy(aggregateFunction = Sum(countAttr), + resultId = NamedExpression.newExprId)), + qualifier) + } else { + return None + } + case Alias(AggregateExpression(_, _, _, _, _), _) => + return None case AttributeReference(_, _, _, _) => viewTableAttr + case Literal(_, _) | Alias(Literal(_, _), _) => + expr + case a@Alias(_, _) => + copyAlias(a, viewTableAttr, qualifier) // other agg like avg or user_defined udaf not support rollUp case _ => return None } } else { - return 
None + expr match { + case Literal(_, _) | Alias(Literal(_, _), _) => + case _ => + logDetail(s"expr:$expr cannot found in viewQueryPlan") + return None + } } newQueryAggExpressions :+= expr.asInstanceOf[NamedExpression] } @@ -223,21 +262,51 @@ class MaterializedViewAggregateRule(sparkSession: SparkSession) } // 5.add project - if (projectFlag) { - // 5.1.not need agg,just project - Some(Project(rewritedQueryAggExpressions.get, viewTablePlan)) - } else { - // 5.2.need agg,rewrite GroupingExpressions and new agg - val rewritedGroupingExpressions = rewriteAndAliasExpressions(newGroupingExpressions, - swapTableColumn = true, tableMapping, columnMapping, - viewProjectList, viewTableAttrs, - newGroupingExpressions.map(_.asInstanceOf[NamedExpression])) - if (rewritedGroupingExpressions.isEmpty) { - return None + val res = + if (projectFlag) { + // 5.1.not need agg,just project + Some(Project(rewritedQueryAggExpressions.get, viewTablePlan)) + } else { + // cast function to alias(NamedExpression) + newGroupingExpressions = newGroupingExpressions.map { + case attr: AttributeReference => + attr + case alias: Alias => + alias + case e => + Alias(e, e.prettyName)() + } + // 5.2.need agg,rewrite GroupingExpressions and new agg + val rewritedGroupingExpressions = rewriteAndAliasExpressions(newGroupingExpressions, + swapTableColumn = true, tableMapping, columnMapping, + viewProjectList, viewTableAttrs, + newGroupingExpressions.map(_.asInstanceOf[NamedExpression])) + if (rewritedGroupingExpressions.isEmpty) { + return None + } + + var rewritedGroupingExpressionsRes = rewritedGroupingExpressions.get + val rewritedGroupingExpressionsSet = rewritedGroupingExpressionsRes + .map(ExpressionEqual).toSet + rewritedQueryAggExpressions.get.foreach { + case alias@Alias(AttributeReference(_, _, _, _), _) => + if (!rewritedGroupingExpressionsSet.contains(ExpressionEqual(alias))) { + rewritedGroupingExpressionsRes +:= alias + } + case attr@AttributeReference(_, _, _, _) => + if (!rewritedGroupingExpressionsSet.contains(ExpressionEqual(attr))) { + rewritedGroupingExpressionsRes +:= attr + } + case _ => + } + + Some(Aggregate(rewritedGroupingExpressionsRes, + rewritedQueryAggExpressions.get, viewTablePlan)) } - Some(Aggregate(rewritedGroupingExpressions.get, - rewritedQueryAggExpressions.get, viewTablePlan)) + if (!RewriteHelper.checkAttrsValid(res.get)) { + return None } + res } def copyAlias(alias: Alias, child: Expression, qualifier: Seq[String]): Alias = { @@ -246,3 +315,66 @@ class MaterializedViewAggregateRule(sparkSession: SparkSession) nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys) } } + +class MaterializedViewOutJoinAggregateRule(sparkSession: SparkSession) + extends MaterializedViewAggregateRule(sparkSession: SparkSession) { + + /** + * check plan if match current rule + * + * @param logicalPlan LogicalPlan + * @return true:matched ; false:unMatched + */ + override def isValidPlan(logicalPlan: LogicalPlan): Boolean = { + if (!logicalPlan.isInstanceOf[Aggregate]) { + return false + } + logicalPlan.children.forall(isValidOutJoinLogicalPlan) + } + + /** + * queryTableInfo!=viewTableInfo , need do join compensate + * + * @param viewTablePlan viewTablePlan + * @param viewQueryPlan viewQueryPlan + * @param topViewProject topViewProject + * @param needTables needTables + * @return join compensated viewTablePlan + */ + override def compensateViewPartial(viewTablePlan: LogicalPlan, + viewQueryPlan: LogicalPlan, + topViewProject: Option[Project], + needTables: Seq[TableEqual]): + Option[(LogicalPlan, 
LogicalPlan, Option[Project])] = { + Some(viewTablePlan, viewQueryPlan, None) + } + + /** + * use viewTablePlan(join compensated) ,query project , + * compensationPredicts to rewrite final plan + * + * @param viewTablePlan viewTablePlan(join compensated) + * @param viewQueryPlan viewQueryPlan + * @param queryPlan queryPlan + * @param tableMapping tableMapping + * @param columnMapping columnMapping + * @param viewProjectList viewProjectList + * @param viewTableAttrs viewTableAttrs + * @return final plan + */ + override def rewriteView(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan, + queryPlan: LogicalPlan, tableMapping: BiMap[String, String], + columnMapping: Map[ExpressionEqual, mutable.Set[ExpressionEqual]], + viewProjectList: Seq[Expression], viewTableAttrs: Seq[Attribute]): + Option[LogicalPlan] = { + val simplifiedViewPlanString = simplifiedPlanString( + findOriginExpression(viewQueryPlan), ALL_CONDITION) + val simplifiedQueryPlanString = simplifiedPlanString( + findOriginExpression(queryPlan), ALL_CONDITION) + if (simplifiedQueryPlanString != simplifiedViewPlanString) { + return None + } + super.rewriteView(viewTablePlan, viewQueryPlan, queryPlan, + tableMapping, columnMapping, viewProjectList, viewTableAttrs) + } +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala similarity index 89% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala index 5c7c477dda7bf2f03dca11d212222b7fcf75c829..e059272892b45df79f1a9e63dc3838fa02ad5296 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRule.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.google.common.collect.BiMap -import com.huawei.boostkit.spark.util.{ExpressionEqual, TableEqual} +import com.huawei.boostkit.spark.util.{ExpressionEqual, RewriteHelper, TableEqual} import scala.collection.mutable import org.apache.spark.sql.SparkSession @@ -50,7 +50,7 @@ class MaterializedViewJoinRule(sparkSession: SparkSession) override def compensateViewPartial(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan, topViewProject: Option[Project], - needTables: Set[TableEqual]): + needTables: Seq[TableEqual]): Option[(LogicalPlan, LogicalPlan, Option[Project])] = { // newViewTablePlan var newViewTablePlan = viewTablePlan @@ -102,7 +102,9 @@ class MaterializedViewJoinRule(sparkSession: SparkSession) // queryProjectList val queryProjectList = extractTopProjectList(queryPlan).map(_.asInstanceOf[NamedExpression]) - val swapQueryProjectList = swapColumnReferences(queryProjectList, columnMapping) + val origins = generateOrigins(queryPlan) + val originQueryProjectList = queryProjectList.map(x => findOriginExpression(origins, x)) + val swapQueryProjectList = swapColumnReferences(originQueryProjectList, columnMapping) // rewrite and alias queryProjectList // if the rewrite expression exprId != origin expression exprId, @@ -111,11 +113,9 @@ class 
MaterializedViewJoinRule(sparkSession: SparkSession) swapTableColumn = true, tableMapping, columnMapping, viewProjectList, viewTableAttrs, queryProjectList) - if (rewritedQueryProjectList.isEmpty) { - return None - } - + val res = Project(rewritedQueryProjectList.get + .map(_.asInstanceOf[NamedExpression]), viewTablePlan) // add project - Some(Project(rewritedQueryProjectList.get, viewTablePlan)) + Some(res) } } diff --git a/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOutJoinRule.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOutJoinRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..d2720f1981d24c32028446dbed1d1637a31f8ef8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOutJoinRule.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import com.google.common.collect.BiMap +import com.huawei.boostkit.spark.util._ +import scala.collection.mutable + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.PushDownPredicates +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} + +class MaterializedViewOutJoinRule(sparkSession: SparkSession) + extends AbstractMaterializedViewRule(sparkSession: SparkSession) { + + /** + * check plan if match current rule + * + * @param logicalPlan LogicalPlan + * @return true:matched ; false:unMatched + */ + def isValidPlan(logicalPlan: LogicalPlan): Boolean = { + isValidOutJoinLogicalPlan(logicalPlan) + } + + /** + * queryTableInfo!=viewTableInfo , need do join compensate + * + * @param viewTablePlan viewTablePlan + * @param viewQueryPlan viewQueryPlan + * @param topViewProject topViewProject + * @param needTables needTables + * @return join compensated viewTablePlan + */ + override def compensateViewPartial(viewTablePlan: LogicalPlan, + viewQueryPlan: LogicalPlan, + topViewProject: Option[Project], + needTables: Seq[TableEqual]): + Option[(LogicalPlan, LogicalPlan, Option[Project])] = { + Some(viewTablePlan, viewQueryPlan, None) + } + + /** + * compute compensationPredicates between viewQueryPlan and mappedQueryPlan + * + * @param viewTablePlan viewTablePlan + * @param queryPredict queryPredict + * @param viewPredict viewPredict + * @param tableMapping tableMapping + * @param columnMapping columnMapping + * @param viewProjectList viewProjectList + * @param viewTableAttrs viewTableAttrs + * @return predictCompensationPlan + */ + override def computeCompensationPredicates(viewTablePlan: LogicalPlan, + queryPredict: (EquivalenceClasses, Seq[ExpressionEqual], + Seq[ExpressionEqual]), + viewPredict: (EquivalenceClasses, Seq[ExpressionEqual], + Seq[ExpressionEqual]), + tableMapping: BiMap[String, String], + columnMapping: Map[ExpressionEqual, mutable.Set[ExpressionEqual]], + viewProjectList: Seq[Expression], viewTableAttrs: Seq[Attribute]): + Option[LogicalPlan] = { + Some(viewTablePlan) + } + + /** + * We map every table in the query to a table with the same qualified + * name (all query tables are contained in the view, thus this is equivalent + * to mapping every table in the query to a view table). 
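 * (Editor's note, inferred from the override body and its "skipSwapTable" comment:)
 * returning only EMPTY_BIMAP skips the table-alias permutation search performed by the
 * base rule; for outer-join views the match is decided instead by comparing simplified
 * plan strings in rewriteView below.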
+ * + * @param queryTables queryTables + * @return + */ + override def generateTableMappings(queryTables: Set[TableEqual]): Seq[BiMap[String, String]] = { + // skipSwapTable + Seq(EMPTY_BIMAP) + } + + /** + * use viewTablePlan(join compensated) ,query project , + * compensationPredicts to rewrite final plan + * + * @param viewTablePlan viewTablePlan(join compensated) + * @param viewQueryPlan viewQueryPlan + * @param queryPlan queryPlan + * @param tableMapping tableMapping + * @param columnMapping columnMapping + * @param viewProjectList viewProjectList + * @param viewTableAttrs viewTableAttrs + * @return final plan + */ + override def rewriteView(viewTablePlan: LogicalPlan, viewQueryPlan: LogicalPlan, + queryPlan: LogicalPlan, tableMapping: BiMap[String, String], + columnMapping: Map[ExpressionEqual, mutable.Set[ExpressionEqual]], + viewProjectList: Seq[Expression], viewTableAttrs: Seq[Attribute]): + Option[LogicalPlan] = { + + val queryOrigins = generateOrigins(queryPlan) + + val viewTableAttrsSet = viewTableAttrs.toSet + val viewOrigins = generateOrigins(viewQueryPlan) + val originViewProjectList = viewProjectList.map(x => findOriginExpression(viewOrigins, x)) + val simplifiedViewPlanString = + simplifiedPlanString(findOriginExpression(viewOrigins, viewQueryPlan), OUTER_JOIN_CONDITION) + + // Push down the topmost filter condition for simplifiedPlanString() matching. + val pushDownQueryPLan = PushDownPredicates.apply(queryPlan) + var rewritten = false + val res = pushDownQueryPLan.transform { + case curPlan: Join => + val simplifiedQueryPlanString = simplifiedPlanString( + findOriginExpression(queryOrigins, curPlan), OUTER_JOIN_CONDITION) + if (simplifiedQueryPlanString == simplifiedViewPlanString) { + // Predicate compensation for matching execution plans. 
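// (Editor's note) At this point the join shapes already match, so only the
// COMPENSABLE_CONDITION predicates can differ: the rule extracts them from the
// original (pre-rewrite) view and query plans and asks the base rule's
// computeCompensationPredicates to add the missing filters on top of viewTablePlan.
// If no valid compensation exists, the join subtree is left unchanged.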
+ val viewExpr = extractPredictExpressions( + findOriginExpression(viewOrigins, viewQueryPlan), EMPTY_BIMAP, COMPENSABLE_CONDITION) + val queryExpr = extractPredictExpressions( + findOriginExpression(queryOrigins, curPlan), EMPTY_BIMAP, COMPENSABLE_CONDITION) + val compensatedViewTablePlan = super.computeCompensationPredicates(viewTablePlan, + queryExpr, viewExpr, tableMapping, columnMapping, + extractTopProjectList(viewQueryPlan), viewTablePlan.output) + + if (compensatedViewTablePlan.isEmpty) { + curPlan + } else { + rewritten = true + val (curProject: Project, _) = extractTables(Project(curPlan.output, curPlan)) + val curProjectList = curProject.projectList + .map(x => findOriginExpression(queryOrigins, x).asInstanceOf[NamedExpression]) + val swapCurProjectList = swapColumnReferences(curProjectList, columnMapping) + val rewritedQueryProjectList = rewriteAndAliasExpressions(swapCurProjectList, + swapTableColumn = true, tableMapping, columnMapping, + originViewProjectList, viewTableAttrs, curProjectList) + + Project(rewritedQueryProjectList.get + .filter(x => isValidExpression(x, viewTableAttrsSet)) + ++ viewTableAttrs.map(_.asInstanceOf[NamedExpression]) + , compensatedViewTablePlan.get) + } + } else { + curPlan + } + case p => p + } + if (rewritten) { + Some(res) + } else { + None + } + } +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionAstBuilder.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionAstBuilder.scala similarity index 85% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionAstBuilder.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionAstBuilder.scala index 1b9ead1fb97f8c08ceb680abc5bfadba16cace90..88bc1995f456c75c3d61ebd9305f4b1054b81952 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionAstBuilder.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionAstBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.parser -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig._ +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig._ import com.huawei.boostkit.spark.util.{RewriteHelper, RewriteLogger} import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.{ParseTree, RuleNode} @@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.{SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.parser.OmniCacheSqlExtensionsParser._ +import org.apache.spark.sql.catalyst.parser.OmniMVSqlExtensionsParser._ import org.apache.spark.sql.catalyst.parser.ParserUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution._ @@ -39,11 +39,11 @@ import org.apache.spark.sql.execution.datasources._ * * @param delegate Spark default ParserInterface */ -class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterface) - extends 
OmniCacheSqlExtensionsBaseVisitor[AnyRef] with SQLConfHelper with RewriteLogger { +class OmniMVExtensionAstBuilder(spark: SparkSession, delegate: ParserInterface) + extends OmniMVSqlExtensionsBaseVisitor[AnyRef] with SQLConfHelper with RewriteLogger { /** - * Parse CreateMVContext to OmniCacheCreateMvCommand + * Parse CreateMVContext to OmniMVCreateMvCommand * * @param ctx the parse tree * */ @@ -70,14 +70,14 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac try { val provider = - OmniCachePluginConfig.getConf.defaultDataSource + OmniMVPluginConfig.getConf.defaultDataSource RewriteHelper.disableCachePlugin() val qe = spark.sql(query).queryExecution val logicalPlan = qe.optimizedPlan if (RewriteHelper.containsMV(qe.analyzed)) { throw new RuntimeException("not support create mv from mv") } - OmniCacheCreateMvCommand(databaseName, name, provider, comment, properties, + OmniMVCreateMvCommand(databaseName, name, provider, comment, properties, ifNotExists, partCols, logicalPlan, logicalPlan.output.map(_.name)) } catch { case e: Throwable => @@ -87,12 +87,12 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac } /** - * Parse CreateMVHeaderContext to OmniCacheHeader + * Parse CreateMVHeaderContext to OmniMVHeader * * @param ctx the parse tree * */ override def visitCreateMVHeader(ctx: CreateMVHeaderContext - ): OmniCacheHeader = withOrigin(ctx) { + ): OmniMVHeader = withOrigin(ctx) { val ifNotExists = ctx.EXISTS() != null val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText) (multipartIdentifier, ifNotExists) @@ -115,7 +115,7 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac s"Table or view not found: $tableIdentifier .") } - var catalogTable = spark.sessionState.catalog.getTableMetadata(tableIdentifier) + val catalogTable = spark.sessionState.catalog.getTableMetadata(tableIdentifier) val queryStr = catalogTable.properties.get(MV_QUERY_ORIGINAL_SQL) if (queryStr.isEmpty) { throw new RuntimeException("cannot refresh a table with refresh mv") @@ -130,6 +130,16 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac } try { spark.sessionState.catalogManager.setCurrentNamespace(Array(curDatabase)) + val fileIndex = spark + .sql(queryStr.get) + .queryExecution + .sparkPlan + .collect { + case FileSourceScanExec(relation, _, _, _, _, _, _, _, _) + => relation.location + case RowDataSourceScanExec(_, _, _, _, _, relation: HadoopFsRelation, _) + => relation.location + } // disable plugin RewriteHelper.disableCachePlugin() val data = spark.sql(queryStr.get).queryExecution.optimizedPlan @@ -155,11 +165,6 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac val partitionColumns = catalogTable.partitionColumnNames PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) - val fileIndex = Some(catalogTable.identifier).map { tableIdent => - spark.table(tableIdent).queryExecution.analyzed.collect { - case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location - }.head - } // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. 
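// (Editor's sketch) The hunk above swaps the old catalog-based FileIndex lookup for a scan
// of the refreshed query's physical plan. A minimal standalone version of the same idea,
// assuming only a SparkSession and the MV's original query text (not the author's exact
// helper), could look like this:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.{FileSourceScanExec, RowDataSourceScanExec}
    import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation}

    object FileIndexCollector {
      // Collect the storage locations touched by a query by pattern-matching its scan nodes.
      def collect(spark: SparkSession, queryText: String): Seq[FileIndex] =
        spark.sql(queryText).queryExecution.sparkPlan.collect {
          // File-based scans carry a HadoopFsRelation whose `location` is the FileIndex.
          case scan: FileSourceScanExec => scan.relation.location
          case scan: RowDataSourceScanExec if scan.relation.isInstanceOf[HadoopFsRelation] =>
            scan.relation.asInstanceOf[HadoopFsRelation].location
        }
    }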
@@ -302,7 +307,7 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac /** * alias tuple2 */ - type OmniCacheHeader = (Seq[String], Boolean) + type OmniMVHeader = (Seq[String], Boolean) /** * Create a comment string. @@ -311,7 +316,7 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac string(ctx.STRING) } - protected def typedVisit[T](ctx: ParseTree): T = { + private def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -323,7 +328,7 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac /** * Create an optional comment string. */ - protected def visitCommentSpecList(ctx: CommentSpecContext): Option[String] = { + private def visitCommentSpecList(ctx: CommentSpecContext): Option[String] = { Option(ctx).map(visitCommentSpec) } @@ -348,4 +353,33 @@ class OmniCacheExtensionAstBuilder(spark: SparkSession, delegate: ParserInterfac override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { ctx.ident.asScala.map(_.getText) } + + override def visitWashOutMV(ctx: WashOutMVContext): LogicalPlan = { + val dropAll = if (ctx.ALL() == null) false else true + val strategy = if (ctx.washOutExpressions() != null) { + visitWashOutExpressions(ctx.washOutExpressions()) + } else { + Option.empty + } + + WashOutMaterializedViewCommand(dropAll, strategy) + } + + override def visitWashOutStrategy(ctx: WashOutStrategyContext): (String, Option[Int]) = { + val key = ctx.children.get(0).getText + if (ctx.children.size() >= 2) { + (key, Some(ctx.children.get(1).getText.toInt)) + } else { + (key, Option.empty) + } + } + + override def visitWashOutExpressions( + ctx: WashOutExpressionsContext): Option[List[(String, Option[Int])]] = withOrigin(ctx) { + if (ctx.washOutStrategy() != null) { + Some(ctx.washOutStrategy().asScala.map(visitWashOutStrategy).toList) + } else { + Option.empty + } + } } diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionSqlParser.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionSqlParser.scala similarity index 85% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionSqlParser.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionSqlParser.scala index bd99f82a30295edc176c03711a1126a527070ab0..a47f9f33caaf3833e68f50966b716b4ba365bf9d 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniCacheExtensionSqlParser.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/catalyst/parser/OmniMVExtensionSqlParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.parser +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig import com.huawei.boostkit.spark.util.RewriteLogger import java.util.Locale import org.antlr.v4.runtime._ @@ -30,12 +31,15 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{DataType, StructType} -class OmniCacheExtensionSqlParser(spark: SparkSession, +class OmniMVExtensionSqlParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface with SQLConfHelper with RewriteLogger { - lazy val astBuilder = new OmniCacheExtensionAstBuilder(spark, delegate) + private 
lazy val astBuilder = new OmniMVExtensionAstBuilder(spark, delegate) override def parsePlan(sqlText: String): LogicalPlan = { + if (OmniMVPluginConfig.getConf.enableSqlLog) { + spark.sparkContext.setJobDescription(sqlText) + } if (isMaterializedViewCommand(sqlText)) { val plan = parse(sqlText) { parser => astBuilder.visit(parser.singleStatement()).asInstanceOf[LogicalPlan] @@ -70,25 +74,26 @@ class OmniCacheExtensionSqlParser(spark: SparkSession, delegate.parseDataType(sqlText) } - def isMaterializedViewCommand(sqlText: String): Boolean = { + private def isMaterializedViewCommand(sqlText: String): Boolean = { val normalized = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ") normalized.contains("show materialized views") || normalized.contains("create materialized view") || normalized.contains("drop materialized view") || normalized.contains("alter materialized view") || - normalized.contains("refresh materialized view") + normalized.contains("refresh materialized view") || + (normalized.contains("wash out") && normalized.contains("materialized view")) } - def parse[T](command: String)(toResult: OmniCacheSqlExtensionsParser => T): T = { + def parse[T](command: String)(toResult: OmniMVSqlExtensionsParser => T): T = { logDebug(s"Parsing command: $command") - val lexer = new OmniCacheSqlExtensionsLexer( + val lexer = new OmniMVSqlExtensionsLexer( new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) val tokenStream = new CommonTokenStream(lexer) - val parser = new OmniCacheSqlExtensionsParser(tokenStream) + val parser = new OmniMVSqlExtensionsParser(tokenStream) parser.addParseListener(PostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) diff --git a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniCacheCommand.scala b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniMVCommand.scala similarity index 74% rename from omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniCacheCommand.scala rename to omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniMVCommand.scala index c053b625236f6a731d175c5ebc8de7abd5c10ef9..0421b8147b99db3295921b23c86fbe2d45f430de 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniCacheCommand.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/main/scala/org/apache/spark/sql/execution/command/OmniMVCommand.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.execution.command -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig._ +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig._ import com.huawei.boostkit.spark.util.{RewriteHelper, ViewMetadata} -import java.io.IOException +import com.huawei.boostkit.spark.util.ViewMetadata._ +import com.huawei.boostkit.spark.util.lock.{FileLock, OmniMVAtomic} +import java.io.{FileNotFoundException, IOException} import java.net.URI +import java.rmi.UnexpectedException import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} +import scala.collection.{mutable, JavaConverters} import scala.util.control.NonFatal import org.apache.spark.internal.io.FileCommitProtocol @@ -33,10 +37,11 @@ import 
org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.getPartitionPathString import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.optimizer.OmniCacheToSparkAdapter._ +import org.apache.spark.sql.catalyst.optimizer.OmniMVToSparkAdapter._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.WashOutStrategy._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -44,7 +49,8 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.util.SchemaUtils -case class OmniCacheCreateMvCommand( + +case class OmniMVCreateMvCommand( databaseNameOption: Option[String], name: String, providerStr: String, @@ -58,6 +64,7 @@ case class OmniCacheCreateMvCommand( override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { try { ViewMetadata.init(sparkSession) + loadViewCount() val sessionState = sparkSession.sessionState val databaseName = databaseNameOption.getOrElse(sessionState.catalog.getCurrentDatabase) val identifier = TableIdentifier(name, Option(databaseName)) @@ -67,7 +74,10 @@ case class OmniCacheCreateMvCommand( val table = buildCatalogTable( identifier, new StructType, - partitioning, None, properties, provider, None, + partitioning, None, + properties ++ Map(MV_LATEST_UPDATE_TIME -> + ViewMetadata.getViewDependsTableTimeStr(query)), + provider, None, comment, storageFormat, external = false) val tableIdentWithDB = identifier.copy(database = Some(databaseName)) @@ -118,6 +128,27 @@ case class OmniCacheCreateMvCommand( } CommandUtils.updateTableStats(sparkSession, table) + + // atomic save ViewMetadata.viewCnt + val dbName = table.identifier.database.getOrElse(DEFAULT_DATABASE) + val dbPath = new Path(metadataPath, dbName) + val dbViewCnt = new Path(dbPath, VIEW_CNT_FILE) + val fileLock = FileLock(fs, new Path(dbPath, VIEW_CNT_FILE_LOCK)) + OmniMVAtomic.funcWithSpinLock(fileLock) { + () => + val viewName = formatViewName(table.identifier) + if (fs.exists(dbViewCnt)) { + val curModifyTime = fs.getFileStatus(dbViewCnt).getModificationTime + if (ViewMetadata.getViewCntModifyTime(viewCnt).getOrElse(0L) != curModifyTime) { + loadViewCount(dbName) + } + } + ViewMetadata.viewCnt.put( + viewName, Array(0, System.currentTimeMillis(), UNLOAD)) + saveViewCountToFile(dbName) + loadViewCount(dbName) + } + ViewMetadata.addCatalogTableToCache(table) } catch { case e: Throwable => @@ -202,7 +233,26 @@ case class DropMaterializedViewCommand( catalog.refreshTable(tableName) catalog.dropTable(tableName, ifExists, purge) // remove mv from cache - ViewMetadata.removeMVCache(tableName) + ViewMetadata.deleteViewMetadata(tableName) + + // atomic del ViewMetadata.viewCnt + val dbName = tableName.database.getOrElse(DEFAULT_DATABASE) + val dbPath = new Path(metadataPath, dbName) + val dbViewCnt = new Path(dbPath, VIEW_CNT_FILE) + val filelock = FileLock(fs, new Path(dbPath, VIEW_CNT_FILE_LOCK)) + OmniMVAtomic.funcWithSpinLock(filelock) { + () => + val viewName = formatViewName(tableName) + if (fs.exists(dbViewCnt)) { + val curModifyTime = 
fs.getFileStatus(dbViewCnt).getModificationTime + if (ViewMetadata.getViewCntModifyTime(viewCnt).getOrElse(0L) != curModifyTime) { + loadViewCount(dbName) + } + } + ViewMetadata.viewCnt.remove(viewName) + saveViewCountToFile(dbName) + loadViewCount(dbName) + } } else if (ifExists) { // no-op } else { @@ -237,13 +287,13 @@ case class ShowMaterializedViewCommand( val catalog = sparkSession.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val omniCacheFilter: TableIdentifier => Boolean = { + val omniMVFilter: TableIdentifier => Boolean = { tableIdentifier => isMV(catalog.getTableMetadata(tableIdentifier)) } val tables = tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - .filter(omniCacheFilter) + .filter(omniMVFilter) if (tableIdentifierPattern.isDefined && tables.isEmpty) { throw new AnalysisException(s"Table or view not found: ${tableIdentifierPattern.get}") } @@ -252,7 +302,7 @@ case class ShowMaterializedViewCommand( case Some(_) => Integer.MAX_VALUE case None => - OmniCachePluginConfig.getConf.showMVQuerySqlLen + OmniMVPluginConfig.getConf.showMVQuerySqlLen } tables.map { tableIdent => val properties = catalog.getTableMetadata(tableIdent).properties @@ -288,6 +338,7 @@ case class AlterRewriteMaterializedViewCommand( if (enableRewrite) { ViewMetadata.addCatalogTableToCache(newTable) } else { + ViewMetadata.addCatalogTableToCache(newTable) ViewMetadata.removeMVCache(tableName) } } else { @@ -308,7 +359,7 @@ case class RefreshMaterializedViewCommand( query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable], - fileIndex: Option[FileIndex], + fileIndex: Seq[FileIndex], outputColumnNames: Seq[String]) extends DataWritingCommand { @@ -540,3 +591,141 @@ case class RefreshMaterializedViewCommand( }.toMap } } + +/** + * Eliminate the least used materialized view. + * + * The syntax of this command is: + * {{{ + * WASH OUT MATERIALIZED VIEW; + * }}} + */ +case class WashOutMaterializedViewCommand( + dropAll: Boolean, + strategy: Option[List[(String, Option[Int])]]) extends RunnableCommand { + + private val logFlag = "[OmniMV]" + + override def run(sparkSession: SparkSession): Seq[Row] = { + ViewMetadata.init(sparkSession) + loadViewCount() + if (dropAll) { + washOutAllMV() + return Seq.empty[Row] + } + if (strategy.isDefined) { + strategy.get.foreach { + infos: (String, Option[Int]) => + infos._1 match { + case UNUSED_DAYS => + washOutByUnUsedDays(infos._2) + case RESERVE_QUANTITY_BY_VIEW_COUNT => + washOutByReserveQuantity(infos._2) + case DROP_QUANTITY_BY_SPACE_CONSUMED => + washOutViewsBySpace(infos._2) + case _ => + } + } + } else { + // default wash out strategy. 
+ washOutByUnUsedDays(Option.empty) + } + + // save wash out timestamp + ViewMetadata.washOutTimestamp = Some(System.currentTimeMillis()) + ViewMetadata.saveWashOutTimestamp() + + Seq.empty[Row] + } + + private def washOutAllMV(): Unit = { + ViewMetadata.viewCnt.forEach { + (viewName, _) => + ViewMetadata.spark.sql("DROP MATERIALIZED VIEW IF EXISTS " + viewName) + } + logInfo(f"$logFlag WASH OUT ALL MATERIALIZED VIEW.") + } + + private def washOutByUnUsedDays(para: Option[Int]): Unit = { + val unUsedDays = para.getOrElse( + OmniMVPluginConfig.getConf.minimumUnusedDaysForWashOut) + val curTime = System.currentTimeMillis() + val threshold = curTime - RewriteHelper.daysToMillisecond(unUsedDays.toLong) + ViewMetadata.viewCnt.forEach { + (viewName, viewInfo) => + if (viewInfo(1) <= threshold) { + ViewMetadata.spark.sql("DROP MATERIALIZED VIEW IF EXISTS " + viewName) + } + } + logInfo(f"$logFlag WASH OUT MATERIALIZED VIEW " + + f"USING $UNUSED_DAYS $unUsedDays.") + } + + private def washOutByReserveQuantity(para: Option[Int]): Unit = { + val reserveQuantity = para.getOrElse( + OmniMVPluginConfig.getConf.reserveViewQuantityByViewCount) + var viewCntList = JavaConverters.mapAsScalaMap(ViewMetadata.viewCnt).toList + if (viewCntList.size <= reserveQuantity) { + return + } + viewCntList = viewCntList.sorted { + (x: (String, Array[Long]), y: (String, Array[Long])) => { + if (y._2(0) != x._2(0)) { + y._2(0).compare(x._2(0)) + } else { + y._2(1).compare(x._2(1)) + } + } + } + for (i <- reserveQuantity until viewCntList.size) { + ViewMetadata.spark.sql("DROP MATERIALIZED VIEW IF EXISTS " + viewCntList(i)._1) + } + logInfo(f"$logFlag WASH OUT MATERIALIZED VIEW " + + f"USING $RESERVE_QUANTITY_BY_VIEW_COUNT $reserveQuantity.") + } + + private def washOutViewsBySpace(para: Option[Int]): Unit = { + val dropQuantity = para.getOrElse( + OmniMVPluginConfig.getConf.dropViewQuantityBySpaceConsumed) + val views = JavaConverters.mapAsScalaMap(ViewMetadata.viewCnt).toList.map(_._1) + val viewInfos = mutable.Map[String, Long]() + views.foreach { + view => + val dbName = view.split("\\.")(0) + val tableName = view.split("\\.")(1) + val tableLocation = ViewMetadata.spark.sessionState.catalog.defaultTablePath( + TableIdentifier(tableName, Some(dbName))) + var spaceConsumed = Long.MaxValue + try { + spaceConsumed = ViewMetadata.fs.getContentSummary( + new Path(tableLocation)).getSpaceConsumed + } catch { + case _: FileNotFoundException => + log.info(f"Can not find table: $tableName. 
It may have been deleted.") + case _ => + throw new UnexpectedException( + "Something unknown happens when wash out views by space") + } finally { + viewInfos.put(view, spaceConsumed) + } + } + val topN = viewInfos.toList.sorted { + (x: (String, Long), y: (String, Long)) => { + y._2.compare(x._2) + } + }.slice(0, dropQuantity) + topN.foreach { + view => + ViewMetadata.spark.sql("DROP MATERIALIZED VIEW IF EXISTS " + view._1) + } + logInfo(f"$logFlag WASH OUT MATERIALIZED VIEW " + + f"USING $DROP_QUANTITY_BY_SPACE_CONSUMED $dropQuantity.") + } + +} + +object WashOutStrategy { + val UNUSED_DAYS = "UNUSED_DAYS" + val RESERVE_QUANTITY_BY_VIEW_COUNT = "RESERVE_QUANTITY_BY_VIEW_COUNT" + val DROP_QUANTITY_BY_SPACE_CONSUMED = "DROP_QUANTITY_BY_SPACE_CONSUMED" +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q1.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q1.sql new file mode 100755 index 0000000000000000000000000000000000000000..4d20faad8ef58f9e0ebbc9cb5c0f7e3dc6c508b4 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q1.sql @@ -0,0 +1,19 @@ +WITH customer_total_return AS +( SELECT + sr_customer_sk AS ctr_customer_sk, + sr_store_sk AS ctr_store_sk, + sum(sr_return_amt) AS ctr_total_return + FROM store_returns, date_dim + WHERE sr_returned_date_sk = d_date_sk AND d_year = 2000 + GROUP BY sr_customer_sk, sr_store_sk) +SELECT c_customer_id +FROM customer_total_return ctr1, store, customer +WHERE ctr1.ctr_total_return > + (SELECT avg(ctr_total_return) * 1.2 + FROM customer_total_return ctr2 + WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk) + AND s_store_sk = ctr1.ctr_store_sk + AND s_state = 'TN' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q10.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q10.sql new file mode 100755 index 0000000000000000000000000000000000000000..5500e1aea1552564587eaafbbd8474f8a73b1390 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q10.sql @@ -0,0 +1,57 @@ +SELECT + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + ca_county IN ('Rush County', 'Toole County', 'Jefferson County', + 'Dona Ana County', 'La Porte County') AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3)) +GROUP BY cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +ORDER BY cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + 
cd_dep_college_count +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q11.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q11.sql new file mode 100755 index 0000000000000000000000000000000000000000..3618fb14fa39c91f42c7e493c62f6963f26fa5a0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q11.sql @@ -0,0 +1,68 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT t_s_secyear.customer_preferred_cust_flag +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +ORDER BY t_s_secyear.customer_preferred_cust_flag +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q12.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q12.sql new file mode 100755 index 0000000000000000000000000000000000000000..0382737f5aa2c0a0351fca850128efe8afdc816e --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q12.sql @@ -0,0 +1,22 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, 
i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q13.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q13.sql new file mode 100755 index 0000000000000000000000000000000000000000..32dc9e26097bad8735b3efc5b232fe8782b45142 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q13.sql @@ -0,0 +1,49 @@ +SELECT + avg(ss_quantity), + avg(ss_ext_sales_price), + avg(ss_ext_wholesale_cost), + sum(ss_ext_wholesale_cost) +FROM store_sales + , store + , customer_demographics + , household_demographics + , customer_address + , date_dim +WHERE s_store_sk = ss_store_sk + AND ss_sold_date_sk = d_date_sk AND d_year = 2001 + AND ((ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'M' + AND cd_education_status = 'Advanced Degree' + AND ss_sales_price BETWEEN 100.00 AND 150.00 + AND hd_dep_count = 3 +) OR + (ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'S' + AND cd_education_status = 'College' + AND ss_sales_price BETWEEN 50.00 AND 100.00 + AND hd_dep_count = 1 + ) OR + (ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'W' + AND cd_education_status = '2 yr Degree' + AND ss_sales_price BETWEEN 150.00 AND 200.00 + AND hd_dep_count = 1 + )) + AND ((ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('TX', 'OH', 'TX') + AND ss_net_profit BETWEEN 100 AND 200 +) OR + (ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('OR', 'NM', 'KY') + AND ss_net_profit BETWEEN 150 AND 300 + ) OR + (ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('VA', 'TX', 'MS') + AND ss_net_profit BETWEEN 50 AND 250 + )) diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14a.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14a.sql new file mode 100755 index 0000000000000000000000000000000000000000..954ddd41be0e6debd3c9485598b23eaa55d29753 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14a.sql @@ -0,0 +1,120 @@ +WITH cross_items AS +(SELECT i_item_sk ss_item_sk + FROM item, + (SELECT + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + FROM store_sales, item iss, date_dim d1 + WHERE ss_item_sk = iss.i_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND d1.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + FROM catalog_sales, item ics, date_dim d2 + WHERE cs_item_sk = ics.i_item_sk + AND cs_sold_date_sk = d2.d_date_sk + AND d2.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + FROM web_sales, item iws, date_dim d3 + WHERE ws_item_sk = iws.i_item_sk + AND ws_sold_date_sk = d3.d_date_sk + AND d3.d_year BETWEEN 1999 AND 1999 + 2) x + WHERE i_brand_id = brand_id + AND i_class_id = class_id + AND i_category_id = category_id +), + avg_sales AS + (SELECT avg(quantity * list_price) average_sales + FROM ( + SELECT + ss_quantity quantity, + ss_list_price list_price + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 2001 + UNION ALL + SELECT + cs_quantity quantity, + cs_list_price list_price + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT 
+ ws_quantity quantity, + ws_list_price list_price + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 1999 + 2) x) +SELECT + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales), + sum(number_sales) +FROM ( + SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales) + UNION ALL + SELECT + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity * cs_list_price) sales, + count(*) number_sales + FROM catalog_sales, item, date_dim + WHERE cs_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(cs_quantity * cs_list_price) > (SELECT average_sales FROM avg_sales) + UNION ALL + SELECT + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity * ws_list_price) sales, + count(*) number_sales + FROM web_sales, item, date_dim + WHERE ws_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ws_quantity * ws_list_price) > (SELECT average_sales + FROM avg_sales) + ) y +GROUP BY ROLLUP (channel, i_brand_id, i_class_id, i_category_id) +ORDER BY channel, i_brand_id, i_class_id, i_category_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14b.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14b.sql new file mode 100755 index 0000000000000000000000000000000000000000..929a8484bf9b4f602f8ef5ffe334a59f3354fbe3 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q14b.sql @@ -0,0 +1,95 @@ +WITH cross_items AS +(SELECT i_item_sk ss_item_sk + FROM item, + (SELECT + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + FROM store_sales, item iss, date_dim d1 + WHERE ss_item_sk = iss.i_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND d1.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + FROM catalog_sales, item ics, date_dim d2 + WHERE cs_item_sk = ics.i_item_sk + AND cs_sold_date_sk = d2.d_date_sk + AND d2.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + FROM web_sales, item iws, date_dim d3 + WHERE ws_item_sk = iws.i_item_sk + AND ws_sold_date_sk = d3.d_date_sk + AND d3.d_year BETWEEN 1999 AND 1999 + 2) x + WHERE i_brand_id = brand_id + AND i_class_id = class_id + AND i_category_id = category_id +), + avg_sales AS + (SELECT avg(quantity * list_price) average_sales + FROM (SELECT + ss_quantity quantity, + ss_list_price list_price + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT + cs_quantity quantity, + cs_list_price list_price + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT + ws_quantity quantity, + 
ws_list_price list_price + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2) x) +SELECT * +FROM + (SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_year = 1999 + 1 AND d_moy = 12 AND d_dom = 11) + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales)) this_year, + (SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_year = 1999 AND d_moy = 12 AND d_dom = 11) + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales)) last_year +WHERE this_year.i_brand_id = last_year.i_brand_id + AND this_year.i_class_id = last_year.i_class_id + AND this_year.i_category_id = last_year.i_category_id +ORDER BY this_year.channel, this_year.i_brand_id, this_year.i_class_id, this_year.i_category_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q15.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q15.sql new file mode 100755 index 0000000000000000000000000000000000000000..b8182e23b019531e619e0f328ddc9ada8cfd1b99 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q15.sql @@ -0,0 +1,15 @@ +SELECT + ca_zip, + sum(cs_sales_price) +FROM catalog_sales, customer, customer_address, date_dim +WHERE cs_bill_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND (substr(ca_zip, 1, 5) IN ('85669', '86197', '88274', '83405', '86475', + '85392', '85460', '80348', '81792') + OR ca_state IN ('CA', 'WA', 'GA') + OR cs_sales_price > 500) + AND cs_sold_date_sk = d_date_sk + AND d_qoy = 2 AND d_year = 2001 +GROUP BY ca_zip +ORDER BY ca_zip +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q16.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q16.sql new file mode 100755 index 0000000000000000000000000000000000000000..732ad0d84807181e208a65f8d0408664b16edca8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q16.sql @@ -0,0 +1,23 @@ +SELECT + count(DISTINCT cs_order_number) AS `order count `, + sum(cs_ext_ship_cost) AS `total shipping cost `, + sum(cs_net_profit) AS `total net profit ` +FROM + catalog_sales cs1, date_dim, customer_address, call_center +WHERE + d_date BETWEEN '2002-02-01' AND (CAST('2002-02-01' AS DATE) + INTERVAL 60 days) + AND cs1.cs_ship_date_sk = d_date_sk + AND cs1.cs_ship_addr_sk = ca_address_sk + AND ca_state = 'GA' + AND cs1.cs_call_center_sk = cc_call_center_sk + AND cc_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + AND EXISTS(SELECT * + FROM catalog_sales cs2 + WHERE cs1.cs_order_number = cs2.cs_order_number + AND cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk) + AND NOT EXISTS(SELECT * + FROM catalog_returns cr1 + WHERE cs1.cs_order_number = 
cr1.cr_order_number) +ORDER BY count(DISTINCT cs_order_number) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q17.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q17.sql new file mode 100755 index 0000000000000000000000000000000000000000..4d647f795600494f88523bfe39f157e2a6edbe96 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q17.sql @@ -0,0 +1,33 @@ +SELECT + i_item_id, + i_item_desc, + s_state, + count(ss_quantity) AS store_sales_quantitycount, + avg(ss_quantity) AS store_sales_quantityave, + stddev_samp(ss_quantity) AS store_sales_quantitystdev, + stddev_samp(ss_quantity) / avg(ss_quantity) AS store_sales_quantitycov, + count(sr_return_quantity) as_store_returns_quantitycount, + avg(sr_return_quantity) as_store_returns_quantityave, + stddev_samp(sr_return_quantity) as_store_returns_quantitystdev, + stddev_samp(sr_return_quantity) / avg(sr_return_quantity) AS store_returns_quantitycov, + count(cs_quantity) AS catalog_sales_quantitycount, + avg(cs_quantity) AS catalog_sales_quantityave, + stddev_samp(cs_quantity) / avg(cs_quantity) AS catalog_sales_quantitystdev, + stddev_samp(cs_quantity) / avg(cs_quantity) AS catalog_sales_quantitycov +FROM store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, date_dim d3, store, item +WHERE d1.d_quarter_name = '2001Q1' + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_quarter_name IN ('2001Q1', '2001Q2', '2001Q3') + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk + AND d3.d_quarter_name IN ('2001Q1', '2001Q2', '2001Q3') +GROUP BY i_item_id, i_item_desc, s_state +ORDER BY i_item_id, i_item_desc, s_state +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q18.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q18.sql new file mode 100755 index 0000000000000000000000000000000000000000..4055c80fdef51063b0fdce504fbc29051154359f --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q18.sql @@ -0,0 +1,28 @@ +SELECT + i_item_id, + ca_country, + ca_state, + ca_county, + avg(cast(cs_quantity AS DECIMAL(12, 2))) agg1, + avg(cast(cs_list_price AS DECIMAL(12, 2))) agg2, + avg(cast(cs_coupon_amt AS DECIMAL(12, 2))) agg3, + avg(cast(cs_sales_price AS DECIMAL(12, 2))) agg4, + avg(cast(cs_net_profit AS DECIMAL(12, 2))) agg5, + avg(cast(c_birth_year AS DECIMAL(12, 2))) agg6, + avg(cast(cd1.cd_dep_count AS DECIMAL(12, 2))) agg7 +FROM catalog_sales, customer_demographics cd1, + customer_demographics cd2, customer, customer_address, date_dim, item +WHERE cs_sold_date_sk = d_date_sk AND + cs_item_sk = i_item_sk AND + cs_bill_cdemo_sk = cd1.cd_demo_sk AND + cs_bill_customer_sk = c_customer_sk AND + cd1.cd_gender = 'F' AND + cd1.cd_education_status = 'Unknown' AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_addr_sk = ca_address_sk AND + c_birth_month IN (1, 6, 8, 9, 12, 2) AND + d_year = 1998 AND + ca_state IN ('MS', 'IN', 'ND', 'OK', 'NM', 'VA', 'MS') +GROUP BY ROLLUP (i_item_id, ca_country, ca_state, ca_county) +ORDER BY ca_country, ca_state, ca_county, i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q19.sql 
b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q19.sql new file mode 100755 index 0000000000000000000000000000000000000000..e38ab7f2683f4be01ce5b9782d2fbbf014de9436 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q19.sql @@ -0,0 +1,19 @@ +SELECT + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +FROM date_dim, store_sales, item, customer, customer_address, store +WHERE d_date_sk = ss_sold_date_sk + AND ss_item_sk = i_item_sk + AND i_manager_id = 8 + AND d_moy = 11 + AND d_year = 1998 + AND ss_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + AND ss_store_sk = s_store_sk +GROUP BY i_brand, i_brand_id, i_manufact_id, i_manufact +ORDER BY ext_price DESC, brand, brand_id, i_manufact_id, i_manufact +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q2.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q2.sql new file mode 100755 index 0000000000000000000000000000000000000000..52c0e90c467407c0359427026dad4f0c28d283b0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q2.sql @@ -0,0 +1,81 @@ +WITH wscs AS +( SELECT + sold_date_sk, + sales_price + FROM (SELECT + ws_sold_date_sk sold_date_sk, + ws_ext_sales_price sales_price + FROM web_sales) x + UNION ALL + (SELECT + cs_sold_date_sk sold_date_sk, + cs_ext_sales_price sales_price + FROM catalog_sales)), + wswscs AS + ( SELECT + d_week_seq, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN sales_price + ELSE NULL END) + sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN sales_price + ELSE NULL END) + mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN sales_price + ELSE NULL END) + tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN sales_price + ELSE NULL END) + wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN sales_price + ELSE NULL END) + thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN sales_price + ELSE NULL END) + fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN sales_price + ELSE NULL END) + sat_sales + FROM wscs, date_dim + WHERE d_date_sk = sold_date_sk + GROUP BY d_week_seq) +SELECT + d_week_seq1, + round(sun_sales1 / sun_sales2, 2), + round(mon_sales1 / mon_sales2, 2), + round(tue_sales1 / tue_sales2, 2), + round(wed_sales1 / wed_sales2, 2), + round(thu_sales1 / thu_sales2, 2), + round(fri_sales1 / fri_sales2, 2), + round(sat_sales1 / sat_sales2, 2) +FROM + (SELECT + wswscs.d_week_seq d_week_seq1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + FROM wswscs, date_dim + WHERE date_dim.d_week_seq = wswscs.d_week_seq AND d_year = 2001) y, + (SELECT + wswscs.d_week_seq d_week_seq2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + FROM wswscs, date_dim + WHERE date_dim.d_week_seq = wswscs.d_week_seq AND d_year = 2001 + 1) z +WHERE d_week_seq1 = d_week_seq2 - 53 +ORDER BY d_week_seq1 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q20.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q20.sql new file mode 100755 index 0000000000000000000000000000000000000000..7ac6c7a75d8ea789b566f0f16d8198cc70fae1df --- /dev/null +++ 
b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q20.sql @@ -0,0 +1,18 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q21.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q21.sql new file mode 100755 index 0000000000000000000000000000000000000000..550881143f8099e82688da5424ed393ea6ec6d5b --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q21.sql @@ -0,0 +1,25 @@ +SELECT * +FROM ( + SELECT + w_warehouse_name, + i_item_id, + sum(CASE WHEN (cast(d_date AS DATE) < cast('2000-03-11' AS DATE)) + THEN inv_quantity_on_hand + ELSE 0 END) AS inv_before, + sum(CASE WHEN (cast(d_date AS DATE) >= cast('2000-03-11' AS DATE)) + THEN inv_quantity_on_hand + ELSE 0 END) AS inv_after + FROM inventory, warehouse, item, date_dim + WHERE i_current_price BETWEEN 0.99 AND 1.49 + AND i_item_sk = inv_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_date BETWEEN (cast('2000-03-11' AS DATE) - INTERVAL 30 days) + AND (cast('2000-03-11' AS DATE) + INTERVAL 30 days) + GROUP BY w_warehouse_name, i_item_id) x +WHERE (CASE WHEN inv_before > 0 + THEN inv_after / inv_before + ELSE NULL + END) BETWEEN 2.0 / 3.0 AND 3.0 / 2.0 +ORDER BY w_warehouse_name, i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q22.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q22.sql new file mode 100755 index 0000000000000000000000000000000000000000..add3b41f7c76c6e61b0dcaf1717d6eebc15c744f --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q22.sql @@ -0,0 +1,14 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23a.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23a.sql new file mode 100755 index 0000000000000000000000000000000000000000..37791f643375ccaa5532007edfa13bef5a5d1c84 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23a.sql @@ -0,0 +1,53 @@ +WITH frequent_ss_items AS +(SELECT + substr(i_item_desc, 1, 30) itemdesc, + i_item_sk item_sk, + d_date solddate, + count(*) cnt + FROM store_sales, date_dim, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY substr(i_item_desc, 1, 30), i_item_sk, d_date + HAVING count(*) > 4), + max_store_sales AS + (SELECT max(csales) tpcds_cmax + FROM (SELECT + c_customer_sk, + sum(ss_quantity * 
ss_sales_price) csales + FROM store_sales, customer, date_dim + WHERE ss_customer_sk = c_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY c_customer_sk) x), + best_ss_customer AS + (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) ssales + FROM store_sales, customer + WHERE ss_customer_sk = c_customer_sk + GROUP BY c_customer_sk + HAVING sum(ss_quantity * ss_sales_price) > (50 / 100.0) * + (SELECT * + FROM max_store_sales)) +SELECT sum(sales) +FROM ((SELECT cs_quantity * cs_list_price sales +FROM catalog_sales, date_dim +WHERE d_year = 2000 + AND d_moy = 2 + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk IN (SELECT item_sk +FROM frequent_ss_items) + AND cs_bill_customer_sk IN (SELECT c_customer_sk +FROM best_ss_customer)) + UNION ALL + (SELECT ws_quantity * ws_list_price sales + FROM web_sales, date_dim + WHERE d_year = 2000 + AND d_moy = 2 + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk IN (SELECT item_sk + FROM frequent_ss_items) + AND ws_bill_customer_sk IN (SELECT c_customer_sk + FROM best_ss_customer))) y +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23b.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23b.sql new file mode 100755 index 0000000000000000000000000000000000000000..01150197af2ba0405bbf70af22df860def8e3a17 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q23b.sql @@ -0,0 +1,68 @@ +WITH frequent_ss_items AS +(SELECT + substr(i_item_desc, 1, 30) itemdesc, + i_item_sk item_sk, + d_date solddate, + count(*) cnt + FROM store_sales, date_dim, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY substr(i_item_desc, 1, 30), i_item_sk, d_date + HAVING count(*) > 4), + max_store_sales AS + (SELECT max(csales) tpcds_cmax + FROM (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) csales + FROM store_sales, customer, date_dim + WHERE ss_customer_sk = c_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY c_customer_sk) x), + best_ss_customer AS + (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) ssales + FROM store_sales + , customer + WHERE ss_customer_sk = c_customer_sk + GROUP BY c_customer_sk + HAVING sum(ss_quantity * ss_sales_price) > (50 / 100.0) * + (SELECT * + FROM max_store_sales)) +SELECT + c_last_name, + c_first_name, + sales +FROM ((SELECT + c_last_name, + c_first_name, + sum(cs_quantity * cs_list_price) sales +FROM catalog_sales, customer, date_dim +WHERE d_year = 2000 + AND d_moy = 2 + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk IN (SELECT item_sk +FROM frequent_ss_items) + AND cs_bill_customer_sk IN (SELECT c_customer_sk +FROM best_ss_customer) + AND cs_bill_customer_sk = c_customer_sk +GROUP BY c_last_name, c_first_name) + UNION ALL + (SELECT + c_last_name, + c_first_name, + sum(ws_quantity * ws_list_price) sales + FROM web_sales, customer, date_dim + WHERE d_year = 2000 + AND d_moy = 2 + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk IN (SELECT item_sk + FROM frequent_ss_items) + AND ws_bill_customer_sk IN (SELECT c_customer_sk + FROM best_ss_customer) + AND ws_bill_customer_sk = c_customer_sk + GROUP BY c_last_name, c_first_name)) y +ORDER BY c_last_name, c_first_name, sales +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24a.sql 
b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24a.sql new file mode 100755 index 0000000000000000000000000000000000000000..bcc189486634da0b24ba98b7a13c793a15083dac --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24a.sql @@ -0,0 +1,34 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24b.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24b.sql new file mode 100755 index 0000000000000000000000000000000000000000..830eb670bcdd220c8ed938292d13e05ba46ec0fb --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q24b.sql @@ -0,0 +1,34 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, + i_color, i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'chiffon' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q25.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q25.sql new file mode 100755 index 0000000000000000000000000000000000000000..a4d78a3c56adc4ae0fcf15677d270d6553a9b954 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q25.sql @@ -0,0 +1,33 @@ +SELECT + i_item_id, + i_item_desc, + s_store_id, + s_store_name, + sum(ss_net_profit) AS store_sales_profit, + sum(sr_net_loss) AS store_returns_loss, + sum(cs_net_profit) AS catalog_sales_profit +FROM + store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, date_dim d3, + store, item +WHERE + d1.d_moy = 4 + AND d1.d_year = 2001 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_moy BETWEEN 4 AND 10 + AND d2.d_year = 2001 + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk 
+ AND d3.d_moy BETWEEN 4 AND 10 + AND d3.d_year = 2001 +GROUP BY + i_item_id, i_item_desc, s_store_id, s_store_name +ORDER BY + i_item_id, i_item_desc, s_store_id, s_store_name +LIMIT 100 \ No newline at end of file diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q26.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q26.sql new file mode 100755 index 0000000000000000000000000000000000000000..6d395a1d791dd2a0eae8b538130e55df7eda20b0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q26.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, + avg(cs_quantity) agg1, + avg(cs_list_price) agg2, + avg(cs_coupon_amt) agg3, + avg(cs_sales_price) agg4 +FROM catalog_sales, customer_demographics, date_dim, item, promotion +WHERE cs_sold_date_sk = d_date_sk AND + cs_item_sk = i_item_sk AND + cs_bill_cdemo_sk = cd_demo_sk AND + cs_promo_sk = p_promo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + (p_channel_email = 'N' OR p_channel_event = 'N') AND + d_year = 2000 +GROUP BY i_item_id +ORDER BY i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q27.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q27.sql new file mode 100755 index 0000000000000000000000000000000000000000..b0e2fd95fd159f972ee5ccc62d414c28c13f8751 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q27.sql @@ -0,0 +1,21 @@ +SELECT + i_item_id, + s_state, + grouping(s_state) g_state, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +FROM store_sales, customer_demographics, date_dim, store, item +WHERE ss_sold_date_sk = d_date_sk AND + ss_item_sk = i_item_sk AND + ss_store_sk = s_store_sk AND + ss_cdemo_sk = cd_demo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + d_year = 2002 AND + s_state IN ('TN', 'TN', 'TN', 'TN', 'TN', 'TN') +GROUP BY ROLLUP (i_item_id, s_state) +ORDER BY i_item_id, s_state +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q28.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q28.sql new file mode 100755 index 0000000000000000000000000000000000000000..f34c2bb0e34e102371c776797b08bd8ac293a836 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q28.sql @@ -0,0 +1,56 @@ +SELECT * +FROM (SELECT + avg(ss_list_price) B1_LP, + count(ss_list_price) B1_CNT, + count(DISTINCT ss_list_price) B1_CNTD +FROM store_sales +WHERE ss_quantity BETWEEN 0 AND 5 + AND (ss_list_price BETWEEN 8 AND 8 + 10 + OR ss_coupon_amt BETWEEN 459 AND 459 + 1000 + OR ss_wholesale_cost BETWEEN 57 AND 57 + 20)) B1, + (SELECT + avg(ss_list_price) B2_LP, + count(ss_list_price) B2_CNT, + count(DISTINCT ss_list_price) B2_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 6 AND 10 + AND (ss_list_price BETWEEN 90 AND 90 + 10 + OR ss_coupon_amt BETWEEN 2323 AND 2323 + 1000 + OR ss_wholesale_cost BETWEEN 31 AND 31 + 20)) B2, + (SELECT + avg(ss_list_price) B3_LP, + count(ss_list_price) B3_CNT, + count(DISTINCT ss_list_price) B3_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 11 AND 15 + AND (ss_list_price BETWEEN 142 AND 142 + 10 + OR ss_coupon_amt BETWEEN 12214 AND 12214 + 1000 + OR ss_wholesale_cost BETWEEN 79 AND 79 + 20)) B3, + (SELECT + avg(ss_list_price) B4_LP, + count(ss_list_price) B4_CNT, + count(DISTINCT ss_list_price) B4_CNTD + FROM store_sales + WHERE 
ss_quantity BETWEEN 16 AND 20 + AND (ss_list_price BETWEEN 135 AND 135 + 10 + OR ss_coupon_amt BETWEEN 6071 AND 6071 + 1000 + OR ss_wholesale_cost BETWEEN 38 AND 38 + 20)) B4, + (SELECT + avg(ss_list_price) B5_LP, + count(ss_list_price) B5_CNT, + count(DISTINCT ss_list_price) B5_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 25 + AND (ss_list_price BETWEEN 122 AND 122 + 10 + OR ss_coupon_amt BETWEEN 836 AND 836 + 1000 + OR ss_wholesale_cost BETWEEN 17 AND 17 + 20)) B5, + (SELECT + avg(ss_list_price) B6_LP, + count(ss_list_price) B6_CNT, + count(DISTINCT ss_list_price) B6_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 26 AND 30 + AND (ss_list_price BETWEEN 154 AND 154 + 10 + OR ss_coupon_amt BETWEEN 7326 AND 7326 + 1000 + OR ss_wholesale_cost BETWEEN 7 AND 7 + 20)) B6 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q29.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q29.sql new file mode 100755 index 0000000000000000000000000000000000000000..3f1fd553f6da8aca54667c1f6854ded0d5bd922a --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q29.sql @@ -0,0 +1,32 @@ +SELECT + i_item_id, + i_item_desc, + s_store_id, + s_store_name, + sum(ss_quantity) AS store_sales_quantity, + sum(sr_return_quantity) AS store_returns_quantity, + sum(cs_quantity) AS catalog_sales_quantity +FROM + store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, + date_dim d3, store, item +WHERE + d1.d_moy = 9 + AND d1.d_year = 1999 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_moy BETWEEN 9 AND 9 + 3 + AND d2.d_year = 1999 + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk + AND d3.d_year IN (1999, 1999 + 1, 1999 + 2) +GROUP BY + i_item_id, i_item_desc, s_store_id, s_store_name +ORDER BY + i_item_id, i_item_desc, s_store_id, s_store_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q3.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q3.sql new file mode 100755 index 0000000000000000000000000000000000000000..181509df9deb7db5bca35835746d1504e08d7839 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q3.sql @@ -0,0 +1,13 @@ +SELECT + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + SUM(ss_ext_sales_price) sum_agg +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manufact_id = 128 + AND dt.d_moy = 11 +GROUP BY dt.d_year, item.i_brand, item.i_brand_id +ORDER BY dt.d_year, sum_agg DESC, brand_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q30.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q30.sql new file mode 100755 index 0000000000000000000000000000000000000000..986bef566d2c866016b38add326460e5e1050e51 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q30.sql @@ -0,0 +1,35 @@ +WITH customer_total_return AS +(SELECT + wr_returning_customer_sk AS ctr_customer_sk, + ca_state AS ctr_state, + sum(wr_return_amt) AS ctr_total_return + FROM web_returns, date_dim, customer_address + WHERE wr_returned_date_sk = d_date_sk + AND d_year = 2002 + AND wr_returning_addr_sk = 
ca_address_sk + GROUP BY wr_returning_customer_sk, ca_state) +SELECT + c_customer_id, + c_salutation, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_day, + c_birth_month, + c_birth_year, + c_birth_country, + c_login, + c_email_address, + c_last_review_date, + ctr_total_return +FROM customer_total_return ctr1, customer_address, customer +WHERE ctr1.ctr_total_return > (SELECT avg(ctr_total_return) * 1.2 +FROM customer_total_return ctr2 +WHERE ctr1.ctr_state = ctr2.ctr_state) + AND ca_address_sk = c_current_addr_sk + AND ca_state = 'GA' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id, c_salutation, c_first_name, c_last_name, c_preferred_cust_flag + , c_birth_day, c_birth_month, c_birth_year, c_birth_country, c_login, c_email_address + , c_last_review_date, ctr_total_return +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q31.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q31.sql new file mode 100755 index 0000000000000000000000000000000000000000..3e543d54364072eaeed7e23b3d062f1728e7d974 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q31.sql @@ -0,0 +1,60 @@ +WITH ss AS +(SELECT + ca_county, + d_qoy, + d_year, + sum(ss_ext_sales_price) AS store_sales + FROM store_sales, date_dim, customer_address + WHERE ss_sold_date_sk = d_date_sk + AND ss_addr_sk = ca_address_sk + GROUP BY ca_county, d_qoy, d_year), + ws AS + (SELECT + ca_county, + d_qoy, + d_year, + sum(ws_ext_sales_price) AS web_sales + FROM web_sales, date_dim, customer_address + WHERE ws_sold_date_sk = d_date_sk + AND ws_bill_addr_sk = ca_address_sk + GROUP BY ca_county, d_qoy, d_year) +SELECT + ss1.ca_county, + ss1.d_year, + ws2.web_sales / ws1.web_sales web_q1_q2_increase, + ss2.store_sales / ss1.store_sales store_q1_q2_increase, + ws3.web_sales / ws2.web_sales web_q2_q3_increase, + ss3.store_sales / ss2.store_sales store_q2_q3_increase +FROM + ss ss1, ss ss2, ss ss3, ws ws1, ws ws2, ws ws3 +WHERE + ss1.d_qoy = 1 + AND ss1.d_year = 2000 + AND ss1.ca_county = ss2.ca_county + AND ss2.d_qoy = 2 + AND ss2.d_year = 2000 + AND ss2.ca_county = ss3.ca_county + AND ss3.d_qoy = 3 + AND ss3.d_year = 2000 + AND ss1.ca_county = ws1.ca_county + AND ws1.d_qoy = 1 + AND ws1.d_year = 2000 + AND ws1.ca_county = ws2.ca_county + AND ws2.d_qoy = 2 + AND ws2.d_year = 2000 + AND ws1.ca_county = ws3.ca_county + AND ws3.d_qoy = 3 + AND ws3.d_year = 2000 + AND CASE WHEN ws1.web_sales > 0 + THEN ws2.web_sales / ws1.web_sales + ELSE NULL END + > CASE WHEN ss1.store_sales > 0 + THEN ss2.store_sales / ss1.store_sales + ELSE NULL END + AND CASE WHEN ws2.web_sales > 0 + THEN ws3.web_sales / ws2.web_sales + ELSE NULL END + > CASE WHEN ss2.store_sales > 0 + THEN ss3.store_sales / ss2.store_sales + ELSE NULL END +ORDER BY ss1.ca_county diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q32.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q32.sql new file mode 100755 index 0000000000000000000000000000000000000000..1d856ca5230450be868918425e34cc8dfdf3cb92 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q32.sql @@ -0,0 +1,15 @@ +SELECT sum(cs_ext_discount_amt) AS `excess discount amount` +FROM + catalog_sales, item, date_dim +WHERE + i_manufact_id = 977 + AND i_item_sk = cs_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + interval 90 days) + AND d_date_sk = cs_sold_date_sk + AND cs_ext_discount_amt > ( + SELECT 1.3 * avg(cs_ext_discount_amt) 
+ FROM catalog_sales, date_dim + WHERE cs_item_sk = i_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + interval 90 days) + AND d_date_sk = cs_sold_date_sk) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q33.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q33.sql new file mode 100755 index 0000000000000000000000000000000000000000..d24856aa5c1eb7417640421f7a1b73ac99e015e8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q33.sql @@ -0,0 +1,65 @@ +WITH ss AS ( + SELECT + i_manufact_id, + sum(ss_ext_sales_price) total_sales + FROM + store_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN (SELECT i_manufact_id + FROM item + WHERE i_category IN ('Electronics')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id), cs AS +(SELECT + i_manufact_id, + sum(cs_ext_sales_price) total_sales + FROM catalog_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN ( + SELECT i_manufact_id + FROM item + WHERE + i_category IN ('Electronics')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id), + ws AS ( + SELECT + i_manufact_id, + sum(ws_ext_sales_price) total_sales + FROM + web_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN (SELECT i_manufact_id + FROM item + WHERE i_category IN ('Electronics')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id) +SELECT + i_manufact_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * + FROM ws) tmp1 +GROUP BY i_manufact_id +ORDER BY total_sales +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q34.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q34.sql new file mode 100755 index 0000000000000000000000000000000000000000..33396bf16e574390fb9abcd6b3fb657fdb0131ec --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q34.sql @@ -0,0 +1,32 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = 
c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY c_last_name, c_first_name, c_salutation, c_preferred_cust_flag DESC diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q35.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q35.sql new file mode 100755 index 0000000000000000000000000000000000000000..cfe4342d8be865a420585a1f8580db3f15925c07 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q35.sql @@ -0,0 +1,46 @@ +SELECT + ca_state, + cd_gender, + cd_marital_status, + count(*) cnt1, + min(cd_dep_count), + max(cd_dep_count), + avg(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + min(cd_dep_employed_count), + max(cd_dep_employed_count), + avg(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + min(cd_dep_college_count), + max(cd_dep_college_count), + avg(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q36.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q36.sql new file mode 100755 index 0000000000000000000000000000000000000000..a8f93df76a34b01f22c8753295c5745f9671d90e --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q36.sql @@ -0,0 +1,26 @@ +SELECT + sum(ss_net_profit) / sum(ss_ext_sales_price) AS gross_margin, + i_category, + i_class, + grouping(i_category) + grouping(i_class) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(i_category) + grouping(i_class), + CASE WHEN grouping(i_class) = 0 + THEN i_category END + ORDER BY sum(ss_net_profit) / sum(ss_ext_sales_price) ASC) AS rank_within_parent +FROM + store_sales, date_dim d1, item, store +WHERE + d1.d_year = 2001 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND s_state IN ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') +GROUP BY ROLLUP (i_category, i_class) +ORDER BY + lochierarchy DESC + , CASE WHEN lochierarchy = 0 + THEN i_category END + , rank_within_parent +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q37.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q37.sql new file mode 100755 index 0000000000000000000000000000000000000000..11b3821fa48b8737710abaf4b30882d0083e0dd0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q37.sql @@ -0,0 +1,15 @@ +SELECT + i_item_id, + i_item_desc, + i_current_price +FROM item, inventory, date_dim, catalog_sales +WHERE i_current_price BETWEEN 68 AND 68 + 30 + AND inv_item_sk = i_item_sk + AND d_date_sk = inv_date_sk + AND d_date BETWEEN cast('2000-02-01' AS DATE) AND (cast('2000-02-01' AS 
DATE) + INTERVAL 60 days) + AND i_manufact_id IN (677, 940, 694, 808) + AND inv_quantity_on_hand BETWEEN 100 AND 500 + AND cs_item_sk = i_item_sk +GROUP BY i_item_id, i_item_desc, i_current_price +ORDER BY i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q38.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q38.sql new file mode 100755 index 0000000000000000000000000000000000000000..1c8d53ee2bbfc04fb392c430a6b80a11e05bdb5d --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q38.sql @@ -0,0 +1,30 @@ +SELECT count(*) +FROM ( + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM store_sales, date_dim, customer + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + INTERSECT + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM catalog_sales, date_dim, customer + WHERE catalog_sales.cs_sold_date_sk = date_dim.d_date_sk + AND catalog_sales.cs_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + INTERSECT + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM web_sales, date_dim, customer + WHERE web_sales.ws_sold_date_sk = date_dim.d_date_sk + AND web_sales.ws_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + ) hot_cust +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39a.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39a.sql new file mode 100755 index 0000000000000000000000000000000000000000..9fc4c1701cf211c2fa6f53ecc0138cf9e1f0c9e1 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39a.sql @@ -0,0 +1,47 @@ +WITH inv AS +(SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stdev, + mean, + CASE mean + WHEN 0 + THEN NULL + ELSE stdev / mean END cov + FROM (SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stddev_samp(inv_quantity_on_hand) stdev, + avg(inv_quantity_on_hand) mean + FROM inventory, item, warehouse, date_dim + WHERE inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_year = 2001 + GROUP BY w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo + WHERE CASE mean + WHEN 0 + THEN 0 + ELSE stdev / mean END > 1) +SELECT + inv1.w_warehouse_sk, + inv1.i_item_sk, + inv1.d_moy, + inv1.mean, + inv1.cov, + inv2.w_warehouse_sk, + inv2.i_item_sk, + inv2.d_moy, + inv2.mean, + inv2.cov +FROM inv inv1, inv inv2 +WHERE inv1.i_item_sk = inv2.i_item_sk + AND inv1.w_warehouse_sk = inv2.w_warehouse_sk + AND inv1.d_moy = 1 + AND inv2.d_moy = 1 + 1 +ORDER BY inv1.w_warehouse_sk, inv1.i_item_sk, inv1.d_moy, inv1.mean, inv1.cov + , inv2.d_moy, inv2.mean, inv2.cov diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39b.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39b.sql new file mode 100755 index 0000000000000000000000000000000000000000..6f8493029fab4d981ea9ae01502e61ff5abd1557 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q39b.sql @@ -0,0 +1,48 @@ +WITH inv AS +(SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stdev, + mean, + CASE mean + WHEN 0 + THEN NULL + ELSE stdev / mean END cov + FROM (SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stddev_samp(inv_quantity_on_hand) stdev, + 
avg(inv_quantity_on_hand) mean + FROM inventory, item, warehouse, date_dim + WHERE inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_year = 2001 + GROUP BY w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo + WHERE CASE mean + WHEN 0 + THEN 0 + ELSE stdev / mean END > 1) +SELECT + inv1.w_warehouse_sk, + inv1.i_item_sk, + inv1.d_moy, + inv1.mean, + inv1.cov, + inv2.w_warehouse_sk, + inv2.i_item_sk, + inv2.d_moy, + inv2.mean, + inv2.cov +FROM inv inv1, inv inv2 +WHERE inv1.i_item_sk = inv2.i_item_sk + AND inv1.w_warehouse_sk = inv2.w_warehouse_sk + AND inv1.d_moy = 1 + AND inv2.d_moy = 1 + 1 + AND inv1.cov > 1.5 +ORDER BY inv1.w_warehouse_sk, inv1.i_item_sk, inv1.d_moy, inv1.mean, inv1.cov + , inv2.d_moy, inv2.mean, inv2.cov diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q4.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q4.sql new file mode 100755 index 0000000000000000000000000000000000000000..b9f27fbc9a4a6855e16026bc4dff46538c62b488 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q4.sql @@ -0,0 +1,120 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(((ss_ext_list_price - ss_ext_wholesale_cost - ss_ext_discount_amt) + + ss_ext_sales_price) / 2) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum((((cs_ext_list_price - cs_ext_wholesale_cost - cs_ext_discount_amt) + + cs_ext_sales_price) / 2)) year_total, + 'c' sale_type + FROM customer, catalog_sales, date_dim + WHERE c_customer_sk = cs_bill_customer_sk AND cs_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum((((ws_ext_list_price - ws_ext_wholesale_cost - ws_ext_discount_amt) + ws_ext_sales_price) / + 2)) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk AND ws_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_preferred_cust_flag, + t_s_secyear.customer_birth_country, + t_s_secyear.customer_login, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear, year_total t_s_secyear, year_total 
t_c_firstyear, + year_total t_c_secyear, year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_c_secyear.customer_id + AND t_s_firstyear.customer_id = t_c_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_c_firstyear.sale_type = 'c' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_c_secyear.sale_type = 'c' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_c_firstyear.dyear = 2001 + AND t_c_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_c_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_c_firstyear.year_total > 0 + THEN t_c_secyear.year_total / t_c_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END + AND CASE WHEN t_c_firstyear.year_total > 0 + THEN t_c_secyear.year_total / t_c_firstyear.year_total + ELSE NULL END + > CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END +ORDER BY + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_preferred_cust_flag, + t_s_secyear.customer_birth_country, + t_s_secyear.customer_login, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q40.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q40.sql new file mode 100755 index 0000000000000000000000000000000000000000..66d8b73ac1c1510b00fc0948b7b3ffb927618680 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q40.sql @@ -0,0 +1,25 @@ +SELECT + w_state, + i_item_id, + sum(CASE WHEN (cast(d_date AS DATE) < cast('2000-03-11' AS DATE)) + THEN cs_sales_price - coalesce(cr_refunded_cash, 0) + ELSE 0 END) AS sales_before, + sum(CASE WHEN (cast(d_date AS DATE) >= cast('2000-03-11' AS DATE)) + THEN cs_sales_price - coalesce(cr_refunded_cash, 0) + ELSE 0 END) AS sales_after +FROM + catalog_sales + LEFT OUTER JOIN catalog_returns ON + (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + , warehouse, item, date_dim +WHERE + i_current_price BETWEEN 0.99 AND 1.49 + AND i_item_sk = cs_item_sk + AND cs_warehouse_sk = w_warehouse_sk + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN (cast('2000-03-11' AS DATE) - INTERVAL 30 days) + AND (cast('2000-03-11' AS DATE) + INTERVAL 30 days) +GROUP BY w_state, i_item_id +ORDER BY w_state, i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q41.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q41.sql new file mode 100755 index 0000000000000000000000000000000000000000..25e317e0e201a74e5911b195d3cb0de15c254ccb --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q41.sql @@ -0,0 +1,49 @@ +SELECT DISTINCT (i_product_name) +FROM item i1 +WHERE i_manufact_id BETWEEN 738 AND 738 + 40 + AND (SELECT count(*) AS item_cnt +FROM item +WHERE (i_manufact = i1.i_manufact AND + ((i_category = 'Women' AND + (i_color = 'powder' OR i_color = 'khaki') AND + (i_units = 'Ounce' OR i_units = 'Oz') AND + (i_size = 'medium' OR i_size = 
'extra large') + ) OR + (i_category = 'Women' AND + (i_color = 'brown' OR i_color = 'honeydew') AND + (i_units = 'Bunch' OR i_units = 'Ton') AND + (i_size = 'N/A' OR i_size = 'small') + ) OR + (i_category = 'Men' AND + (i_color = 'floral' OR i_color = 'deep') AND + (i_units = 'N/A' OR i_units = 'Dozen') AND + (i_size = 'petite' OR i_size = 'large') + ) OR + (i_category = 'Men' AND + (i_color = 'light' OR i_color = 'cornflower') AND + (i_units = 'Box' OR i_units = 'Pound') AND + (i_size = 'medium' OR i_size = 'extra large') + ))) OR + (i_manufact = i1.i_manufact AND + ((i_category = 'Women' AND + (i_color = 'midnight' OR i_color = 'snow') AND + (i_units = 'Pallet' OR i_units = 'Gross') AND + (i_size = 'medium' OR i_size = 'extra large') + ) OR + (i_category = 'Women' AND + (i_color = 'cyan' OR i_color = 'papaya') AND + (i_units = 'Cup' OR i_units = 'Dram') AND + (i_size = 'N/A' OR i_size = 'small') + ) OR + (i_category = 'Men' AND + (i_color = 'orange' OR i_color = 'frosted') AND + (i_units = 'Each' OR i_units = 'Tbl') AND + (i_size = 'petite' OR i_size = 'large') + ) OR + (i_category = 'Men' AND + (i_color = 'forest' OR i_color = 'ghost') AND + (i_units = 'Lb' OR i_units = 'Bundle') AND + (i_size = 'medium' OR i_size = 'extra large') + )))) > 0 +ORDER BY i_product_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q42.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q42.sql new file mode 100755 index 0000000000000000000000000000000000000000..4d2e71760d8701c2235ff85471c3953188e97f95 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q42.sql @@ -0,0 +1,18 @@ +SELECT + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manager_id = 1 + AND dt.d_moy = 11 + AND dt.d_year = 2000 +GROUP BY dt.d_year + , item.i_category_id + , item.i_category +ORDER BY sum(ss_ext_sales_price) DESC, dt.d_year + , item.i_category_id + , item.i_category +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q43.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q43.sql new file mode 100755 index 0000000000000000000000000000000000000000..45411772c1b54322567d808a736948d1a961f460 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q43.sql @@ -0,0 +1,33 @@ +SELECT + s_store_name, + s_store_id, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN ss_sales_price + ELSE NULL END) sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN ss_sales_price + ELSE NULL END) mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN ss_sales_price + ELSE NULL END) tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN ss_sales_price + ELSE NULL END) wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN ss_sales_price + ELSE NULL END) thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN ss_sales_price + ELSE NULL END) fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN ss_sales_price + ELSE NULL END) sat_sales +FROM date_dim, store_sales, store +WHERE d_date_sk = ss_sold_date_sk AND + s_store_sk = ss_store_sk AND + s_gmt_offset = -5 AND + d_year = 2000 +GROUP BY s_store_name, s_store_id +ORDER BY s_store_name, s_store_id, sun_sales, mon_sales, tue_sales, wed_sales, + thu_sales, fri_sales, sat_sales +LIMIT 100 diff --git 
a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q44.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q44.sql new file mode 100755 index 0000000000000000000000000000000000000000..379e604788625c31e910ee261a45de84ec3286c2 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q44.sql @@ -0,0 +1,46 @@ +SELECT + asceding.rnk, + i1.i_product_name best_performing, + i2.i_product_name worst_performing +FROM (SELECT * +FROM (SELECT + item_sk, + rank() + OVER ( + ORDER BY rank_col ASC) rnk +FROM (SELECT + ss_item_sk item_sk, + avg(ss_net_profit) rank_col +FROM store_sales ss1 +WHERE ss_store_sk = 4 +GROUP BY ss_item_sk +HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col +FROM store_sales +WHERE ss_store_sk = 4 + AND ss_addr_sk IS NULL +GROUP BY ss_store_sk)) V1) V11 +WHERE rnk < 11) asceding, + (SELECT * + FROM (SELECT + item_sk, + rank() + OVER ( + ORDER BY rank_col DESC) rnk + FROM (SELECT + ss_item_sk item_sk, + avg(ss_net_profit) rank_col + FROM store_sales ss1 + WHERE ss_store_sk = 4 + GROUP BY ss_item_sk + HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col + FROM store_sales + WHERE ss_store_sk = 4 + AND ss_addr_sk IS NULL + GROUP BY ss_store_sk)) V2) V21 + WHERE rnk < 11) descending, + item i1, item i2 +WHERE asceding.rnk = descending.rnk + AND i1.i_item_sk = asceding.item_sk + AND i2.i_item_sk = descending.item_sk +ORDER BY asceding.rnk +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q45.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q45.sql new file mode 100755 index 0000000000000000000000000000000000000000..907438f196c4c9403a6d52f0a1e9c175d0866a81 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q45.sql @@ -0,0 +1,21 @@ +SELECT + ca_zip, + ca_city, + sum(ws_sales_price) +FROM web_sales, customer, customer_address, date_dim, item +WHERE ws_bill_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND ws_item_sk = i_item_sk + AND (substr(ca_zip, 1, 5) IN + ('85669', '86197', '88274', '83405', '86475', '85392', '85460', '80348', '81792') + OR + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_item_sk IN (2, 3, 5, 7, 11, 13, 17, 19, 23, 29) + ) +) + AND ws_sold_date_sk = d_date_sk + AND d_qoy = 2 AND d_year = 2001 +GROUP BY ca_zip, ca_city +ORDER BY ca_zip, ca_city +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q46.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q46.sql new file mode 100755 index 0000000000000000000000000000000000000000..0911677dff206b72d72f2a95a28226174b102be7 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q46.sql @@ -0,0 +1,32 @@ +SELECT + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + FROM store_sales, date_dim, store, household_demographics, customer_address + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND store_sales.ss_addr_sk = customer_address.ca_address_sk + AND (household_demographics.hd_dep_count = 4 OR + household_demographics.hd_vehicle_count = 3) + AND date_dim.d_dow IN (6, 0) + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_city IN ('Fairview', 
'Midway', 'Fairview', 'Fairview', 'Fairview') + GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, ca_city) dn, customer, + customer_address current_addr +WHERE ss_customer_sk = c_customer_sk + AND customer.c_current_addr_sk = current_addr.ca_address_sk + AND current_addr.ca_city <> bought_city +ORDER BY c_last_name, c_first_name, ca_city, bought_city, ss_ticket_number +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q47.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q47.sql new file mode 100755 index 0000000000000000000000000000000000000000..cfc37a4cece667b53bc1ebe07d38a68470ad78be --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q47.sql @@ -0,0 +1,63 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + v1.s_store_name, + v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q48.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q48.sql new file mode 100755 index 0000000000000000000000000000000000000000..fdb9f38e294f7ffe08f1e2b32fb72c10486a1587 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q48.sql @@ -0,0 +1,63 @@ +SELECT sum(ss_quantity) +FROM store_sales, store, customer_demographics, customer_address, date_dim +WHERE s_store_sk = ss_store_sk + AND ss_sold_date_sk = d_date_sk AND d_year = 2001 + AND + ( + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'M' + AND + cd_education_status = '4 yr Degree' + AND + ss_sales_price BETWEEN 100.00 AND 150.00 + ) + OR + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'D' + AND + cd_education_status = '2 yr Degree' + AND + ss_sales_price BETWEEN 50.00 AND 100.00 + ) + OR + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'S' + AND + cd_education_status = 'College' + AND + ss_sales_price BETWEEN 150.00 AND 200.00 + ) + ) + AND + ( + ( + ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('CO', 'OH', 'TX') + AND ss_net_profit BETWEEN 0 AND 
2000 + ) + OR + (ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('OR', 'MN', 'KY') + AND ss_net_profit BETWEEN 150 AND 3000 + ) + OR + (ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('VA', 'CA', 'MS') + AND ss_net_profit BETWEEN 50 AND 25000 + ) + ) diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q49.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q49.sql new file mode 100755 index 0000000000000000000000000000000000000000..9568d8b92d10a10dd417a99565df636cb1d07569 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q49.sql @@ -0,0 +1,126 @@ +SELECT + 'web' AS channel, + web.item, + web.return_ratio, + web.return_rank, + web.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + ws.ws_item_sk AS item, + (cast(sum(coalesce(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(ws.ws_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(ws.ws_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND + ws.ws_item_sk = wr.wr_item_sk) + , date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY ws.ws_item_sk + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = 
sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY 1, 4, 5 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q5.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q5.sql new file mode 100755 index 0000000000000000000000000000000000000000..b87cf3a44827b1757e12f2f6186fe87db7c46d9b --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q5.sql @@ -0,0 +1,131 @@ +WITH ssr AS +( SELECT + s_store_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + ss_store_sk AS store_sk, + ss_sold_date_sk AS date_sk, + ss_ext_sales_price AS sales_price, + ss_net_profit AS profit, + cast(0 AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM store_sales + UNION ALL + SELECT + sr_store_sk AS store_sk, + sr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + sr_return_amt AS return_amt, + sr_net_loss AS net_loss + FROM store_returns) + salesreturns, date_dim, store + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND store_sk = s_store_sk + GROUP BY s_store_id), + csr AS + ( SELECT + cp_catalog_page_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + cs_catalog_page_sk AS page_sk, + cs_sold_date_sk AS date_sk, + cs_ext_sales_price AS sales_price, + cs_net_profit AS profit, + cast(0 AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM catalog_sales + UNION ALL + SELECT + cr_catalog_page_sk AS page_sk, + cr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + cr_return_amount AS return_amt, + cr_net_loss AS net_loss + FROM catalog_returns + ) salesreturns, date_dim, catalog_page + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND page_sk = cp_catalog_page_sk + GROUP BY cp_catalog_page_id) + , + wsr AS + ( SELECT + web_site_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + ws_web_site_sk AS wsr_web_site_sk, + ws_sold_date_sk AS date_sk, + ws_ext_sales_price AS sales_price, + ws_net_profit AS profit, + cast(0 AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM web_sales + UNION ALL + SELECT + ws_web_site_sk AS wsr_web_site_sk, + wr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + wr_return_amt AS return_amt, + wr_net_loss AS net_loss + FROM web_returns + LEFT OUTER JOIN web_sales ON + (wr_item_sk = ws_item_sk + AND wr_order_number = ws_order_number) + ) salesreturns, date_dim, web_site + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND wsr_web_site_sk = web_site_sk + GROUP BY web_site_id) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS 
returns, + sum(profit) AS profit +FROM + (SELECT + 'store channel' AS channel, + concat('store', s_store_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM ssr + UNION ALL + SELECT + 'catalog channel' AS channel, + concat('catalog_page', cp_catalog_page_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM csr + UNION ALL + SELECT + 'web channel' AS channel, + concat('web_site', web_site_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM wsr + ) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q50.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q50.sql new file mode 100755 index 0000000000000000000000000000000000000000..f1d4b15449edd708d3132092d79081648b3357d2 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q50.sql @@ -0,0 +1,47 @@ +SELECT + s_store_name, + s_company_id, + s_street_number, + s_street_name, + s_street_type, + s_suite_number, + s_city, + s_county, + s_state, + s_zip, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 30) AND + (sr_returned_date_sk - ss_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 60) AND + (sr_returned_date_sk - ss_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 90) AND + (sr_returned_date_sk - ss_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + store_sales, store_returns, store, date_dim d1, date_dim d2 +WHERE + d2.d_year = 2001 + AND d2.d_moy = 8 + AND ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND sr_returned_date_sk = d2.d_date_sk + AND ss_customer_sk = sr_customer_sk + AND ss_store_sk = s_store_sk +GROUP BY + s_store_name, s_company_id, s_street_number, s_street_name, s_street_type, + s_suite_number, s_city, s_county, s_state, s_zip +ORDER BY + s_store_name, s_company_id, s_street_number, s_street_name, s_street_type, + s_suite_number, s_city, s_county, s_state, s_zip +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q51.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q51.sql new file mode 100755 index 0000000000000000000000000000000000000000..62b003eb67b9b2cfb6187b4e5ea95a72ad42113f --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q51.sql @@ -0,0 +1,55 @@ +WITH web_v1 AS ( + SELECT + ws_item_sk item_sk, + d_date, + sum(sum(ws_sales_price)) + OVER (PARTITION BY ws_item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ws_item_sk IS NOT NULL + GROUP BY ws_item_sk, d_date), + store_v1 AS ( + SELECT + ss_item_sk item_sk, + d_date, + sum(sum(ss_sales_price)) + OVER (PARTITION BY ss_item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ss_item_sk IS NOT NULL + GROUP BY ss_item_sk, d_date) +SELECT * +FROM (SELECT + item_sk, + d_date, + 
web_sales, + store_sales, + max(web_sales) + OVER (PARTITION BY item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) web_cumulative, + max(store_sales) + OVER (PARTITION BY item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) store_cumulative +FROM (SELECT + CASE WHEN web.item_sk IS NOT NULL + THEN web.item_sk + ELSE store.item_sk END item_sk, + CASE WHEN web.d_date IS NOT NULL + THEN web.d_date + ELSE store.d_date END d_date, + web.cume_sales web_sales, + store.cume_sales store_sales +FROM web_v1 web FULL OUTER JOIN store_v1 store ON (web.item_sk = store.item_sk + AND web.d_date = store.d_date) + ) x) y +WHERE web_cumulative > store_cumulative +ORDER BY item_sk, d_date +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q52.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q52.sql new file mode 100755 index 0000000000000000000000000000000000000000..467d1ae05045579d7fb16526c2944049b126747b --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q52.sql @@ -0,0 +1,14 @@ +SELECT + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manager_id = 1 + AND dt.d_moy = 11 + AND dt.d_year = 2000 +GROUP BY dt.d_year, item.i_brand, item.i_brand_id +ORDER BY dt.d_year, ext_price DESC, brand_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q53.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q53.sql new file mode 100755 index 0000000000000000000000000000000000000000..b42c68dcf871b902df1988cbeb90debf696ade34 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q53.sql @@ -0,0 +1,30 @@ +SELECT * +FROM + (SELECT + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER (PARTITION BY i_manufact_id) avg_quarterly_sales + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + d_month_seq IN (1200, 1200 + 1, 1200 + 2, 1200 + 3, 1200 + 4, 1200 + 5, 1200 + 6, + 1200 + 7, 1200 + 8, 1200 + 9, 1200 + 10, 1200 + 11) AND + ((i_category IN ('Books', 'Children', 'Electronics') AND + i_class IN ('personal', 'portable', 'reference', 'self-help') AND + i_brand IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', + 'exportiunivamalg #9', 'scholaramalgamalg #9')) + OR + (i_category IN ('Women', 'Music', 'Men') AND + i_class IN ('accessories', 'classical', 'fragrances', 'pants') AND + i_brand IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', + 'importoamalg #1'))) + GROUP BY i_manufact_id, d_qoy) tmp1 +WHERE CASE WHEN avg_quarterly_sales > 0 + THEN abs(sum_sales - avg_quarterly_sales) / avg_quarterly_sales + ELSE NULL END > 0.1 +ORDER BY avg_quarterly_sales, + sum_sales, + i_manufact_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q54.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q54.sql new file mode 100755 index 0000000000000000000000000000000000000000..897237fb6e10b4e3e1afb9dd58240a14176ff996 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q54.sql @@ -0,0 +1,61 @@ +WITH my_customers AS ( + SELECT DISTINCT + c_customer_sk, + c_current_addr_sk + FROM + (SELECT + cs_sold_date_sk sold_date_sk, + 
cs_bill_customer_sk customer_sk, + cs_item_sk item_sk + FROM catalog_sales + UNION ALL + SELECT + ws_sold_date_sk sold_date_sk, + ws_bill_customer_sk customer_sk, + ws_item_sk item_sk + FROM web_sales + ) cs_or_ws_sales, + item, + date_dim, + customer + WHERE sold_date_sk = d_date_sk + AND item_sk = i_item_sk + AND i_category = 'Women' + AND i_class = 'maternity' + AND c_customer_sk = cs_or_ws_sales.customer_sk + AND d_moy = 12 + AND d_year = 1998 +) + , my_revenue AS ( + SELECT + c_customer_sk, + sum(ss_ext_sales_price) AS revenue + FROM my_customers, + store_sales, + customer_address, + store, + date_dim + WHERE c_current_addr_sk = ca_address_sk + AND ca_county = s_county + AND ca_state = s_state + AND ss_sold_date_sk = d_date_sk + AND c_customer_sk = ss_customer_sk + AND d_month_seq BETWEEN (SELECT DISTINCT d_month_seq + 1 + FROM date_dim + WHERE d_year = 1998 AND d_moy = 12) + AND (SELECT DISTINCT d_month_seq + 3 + FROM date_dim + WHERE d_year = 1998 AND d_moy = 12) + GROUP BY c_customer_sk +) + , segments AS +(SELECT cast((revenue / 50) AS INT) AS segment + FROM my_revenue) +SELECT + segment, + count(*) AS num_customers, + segment * 50 AS segment_base +FROM segments +GROUP BY segment +ORDER BY segment, num_customers +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q55.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q55.sql new file mode 100755 index 0000000000000000000000000000000000000000..bc5d888c9ac5852c80b8ab2638ca8049d7ea4e4e --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q55.sql @@ -0,0 +1,13 @@ +SELECT + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +FROM date_dim, store_sales, item +WHERE d_date_sk = ss_sold_date_sk + AND ss_item_sk = i_item_sk + AND i_manager_id = 28 + AND d_moy = 11 + AND d_year = 1999 +GROUP BY i_brand, i_brand_id +ORDER BY ext_price DESC, brand_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q56.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q56.sql new file mode 100755 index 0000000000000000000000000000000000000000..2fa1738dcfee6e4aaa1fa48c3901e953d83e5407 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q56.sql @@ -0,0 +1,65 @@ +WITH ss AS ( + SELECT + i_item_id, + sum(ss_ext_sales_price) total_sales + FROM + store_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + cs AS ( + SELECT + i_item_id, + sum(cs_ext_sales_price) total_sales + FROM + catalog_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + ws AS ( + SELECT + i_item_id, + sum(ws_ext_sales_price) total_sales + FROM + web_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP 
BY i_item_id) +SELECT + i_item_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * + FROM ws) tmp1 +GROUP BY i_item_id +ORDER BY total_sales +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q57.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q57.sql new file mode 100755 index 0000000000000000000000000000000000000000..cf70d4b905b553f0488aa60eebc7207c4b8e129c --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q57.sql @@ -0,0 +1,56 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q58.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q58.sql new file mode 100755 index 0000000000000000000000000000000000000000..5f63f33dc927cb30d9b376848a68fccf4dfbacee --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q58.sql @@ -0,0 +1,59 @@ +WITH ss_items AS +(SELECT + i_item_id item_id, + sum(ss_ext_sales_price) ss_item_rev + FROM store_sales, item, date_dim + WHERE ss_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND ss_sold_date_sk = d_date_sk + GROUP BY i_item_id), + cs_items AS + (SELECT + i_item_id item_id, + sum(cs_ext_sales_price) cs_item_rev + FROM catalog_sales, item, date_dim + WHERE cs_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND cs_sold_date_sk = d_date_sk + GROUP BY i_item_id), + ws_items AS + (SELECT + i_item_id item_id, + sum(ws_ext_sales_price) ws_item_rev + FROM web_sales, item, date_dim + WHERE ws_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND ws_sold_date_sk = d_date_sk + GROUP BY i_item_id) +SELECT + ss_items.item_id, + ss_item_rev, + ss_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 100 ss_dev, + cs_item_rev, + cs_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 
100 cs_dev, + ws_item_rev, + ws_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 100 ws_dev, + (ss_item_rev + cs_item_rev + ws_item_rev) / 3 average +FROM ss_items, cs_items, ws_items +WHERE ss_items.item_id = cs_items.item_id + AND ss_items.item_id = ws_items.item_id + AND ss_item_rev BETWEEN 0.9 * cs_item_rev AND 1.1 * cs_item_rev + AND ss_item_rev BETWEEN 0.9 * ws_item_rev AND 1.1 * ws_item_rev + AND cs_item_rev BETWEEN 0.9 * ss_item_rev AND 1.1 * ss_item_rev + AND cs_item_rev BETWEEN 0.9 * ws_item_rev AND 1.1 * ws_item_rev + AND ws_item_rev BETWEEN 0.9 * ss_item_rev AND 1.1 * ss_item_rev + AND ws_item_rev BETWEEN 0.9 * cs_item_rev AND 1.1 * cs_item_rev +ORDER BY item_id, ss_item_rev +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q59.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q59.sql new file mode 100755 index 0000000000000000000000000000000000000000..3cef2027680b077cfdc86661b2867d584650ca03 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q59.sql @@ -0,0 +1,75 @@ +WITH wss AS +(SELECT + d_week_seq, + ss_store_sk, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN ss_sales_price + ELSE NULL END) sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN ss_sales_price + ELSE NULL END) mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN ss_sales_price + ELSE NULL END) tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN ss_sales_price + ELSE NULL END) wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN ss_sales_price + ELSE NULL END) thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN ss_sales_price + ELSE NULL END) fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN ss_sales_price + ELSE NULL END) sat_sales + FROM store_sales, date_dim + WHERE d_date_sk = ss_sold_date_sk + GROUP BY d_week_seq, ss_store_sk +) +SELECT + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales2, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +FROM + (SELECT + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + FROM wss, store, date_dim d + WHERE d.d_week_seq = wss.d_week_seq AND + ss_store_sk = s_store_sk AND + d_month_seq BETWEEN 1212 AND 1212 + 11) y, + (SELECT + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + FROM wss, store, date_dim d + WHERE d.d_week_seq = wss.d_week_seq AND + ss_store_sk = s_store_sk AND + d_month_seq BETWEEN 1212 + 12 AND 1212 + 23) x +WHERE s_store_id1 = s_store_id2 + AND d_week_seq1 = d_week_seq2 - 52 +ORDER BY s_store_name1, s_store_id1, d_week_seq1 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q6.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q6.sql new file mode 100755 index 0000000000000000000000000000000000000000..f0f5cf05aebda3cfaf8d22d3250516c76f4169d8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q6.sql @@ -0,0 +1,21 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, 
date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +ORDER BY cnt +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q60.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q60.sql new file mode 100755 index 0000000000000000000000000000000000000000..41b963f44ba1346783a025be1c3576574377ac97 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q60.sql @@ -0,0 +1,62 @@ +WITH ss AS ( + SELECT + i_item_id, + sum(ss_ext_sales_price) total_sales + FROM store_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + cs AS ( + SELECT + i_item_id, + sum(cs_ext_sales_price) total_sales + FROM catalog_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + ws AS ( + SELECT + i_item_id, + sum(ws_ext_sales_price) total_sales + FROM web_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id) +SELECT + i_item_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * + FROM ws) tmp1 +GROUP BY i_item_id +ORDER BY i_item_id, total_sales +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q61.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q61.sql new file mode 100755 index 0000000000000000000000000000000000000000..b0a872b4b80e103eedf36d925980441c19e23662 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q61.sql @@ -0,0 +1,33 @@ +SELECT + promotions, + total, + cast(promotions AS DECIMAL(15, 4)) / cast(total AS DECIMAL(15, 4)) * 100 +FROM + (SELECT sum(ss_ext_sales_price) promotions + FROM store_sales, store, promotion, date_dim, customer, customer_address, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_store_sk = s_store_sk + AND ss_promo_sk = p_promo_sk + AND ss_customer_sk = c_customer_sk + AND ca_address_sk = c_current_addr_sk + AND ss_item_sk = i_item_sk + AND ca_gmt_offset = -5 + AND i_category = 'Jewelry' + AND (p_channel_dmail = 'Y' OR p_channel_email = 'Y' OR p_channel_tv = 'Y') + AND s_gmt_offset = -5 + AND d_year = 1998 + AND d_moy = 11) promotional_sales, + (SELECT sum(ss_ext_sales_price) total + FROM store_sales, store, date_dim, customer, customer_address, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_store_sk = s_store_sk + AND ss_customer_sk = c_customer_sk + AND ca_address_sk = c_current_addr_sk + AND ss_item_sk = i_item_sk + AND 
ca_gmt_offset = -5 + AND i_category = 'Jewelry' + AND s_gmt_offset = -5 + AND d_year = 1998 + AND d_moy = 11) all_sales +ORDER BY promotions, total +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q62.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q62.sql new file mode 100755 index 0000000000000000000000000000000000000000..8a414f154bdc87157ff28f69699f29c0a4de730e --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q62.sql @@ -0,0 +1,35 @@ +SELECT + substr(w_warehouse_name, 1, 20), + sm_type, + web_name, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 30) AND + (ws_ship_date_sk - ws_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 60) AND + (ws_ship_date_sk - ws_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 90) AND + (ws_ship_date_sk - ws_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + web_sales, warehouse, ship_mode, web_site, date_dim +WHERE + d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ws_ship_date_sk = d_date_sk + AND ws_warehouse_sk = w_warehouse_sk + AND ws_ship_mode_sk = sm_ship_mode_sk + AND ws_web_site_sk = web_site_sk +GROUP BY + substr(w_warehouse_name, 1, 20), sm_type, web_name +ORDER BY + substr(w_warehouse_name, 1, 20), sm_type, web_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q63.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q63.sql new file mode 100755 index 0000000000000000000000000000000000000000..ef6867e0a945163fb3c54a8dde014cd4e6094af6 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q63.sql @@ -0,0 +1,31 @@ +SELECT * +FROM (SELECT + i_manager_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER (PARTITION BY i_manager_id) avg_monthly_sales +FROM item + , store_sales + , date_dim + , store +WHERE ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND ss_store_sk = s_store_sk + AND d_month_seq IN (1200, 1200 + 1, 1200 + 2, 1200 + 3, 1200 + 4, 1200 + 5, 1200 + 6, 1200 + 7, + 1200 + 8, 1200 + 9, 1200 + 10, 1200 + 11) + AND ((i_category IN ('Books', 'Children', 'Electronics') + AND i_class IN ('personal', 'portable', 'refernece', 'self-help') + AND i_brand IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', + 'exportiunivamalg #9', 'scholaramalgamalg #9')) + OR (i_category IN ('Women', 'Music', 'Men') + AND i_class IN ('accessories', 'classical', 'fragrances', 'pants') + AND i_brand IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', + 'importoamalg #1'))) +GROUP BY i_manager_id, d_moy) tmp1 +WHERE CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY i_manager_id + , avg_monthly_sales + , sum_sales +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q64.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q64.sql new file mode 100755 index 0000000000000000000000000000000000000000..8ec1d31b61afeab401aa79ce04257dd5cd2a8661 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q64.sql @@ -0,0 +1,92 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + 
sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY i_product_name, i_item_sk, s_store_name, s_zip, ad1.ca_street_number, + ad1.ca_street_name, ad1.ca_city, ad1.ca_zip, ad2.ca_street_number, + ad2.ca_street_name, ad2.ca_city, ad2.ca_zip, d1.d_year, d2.d_year, d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY cs1.product_name, cs1.store_name, cs2.cnt diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q65.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q65.sql new file mode 100755 index 0000000000000000000000000000000000000000..aad04be1bcdf04d3fc97e57bd7a52e23ea95b7e8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q65.sql @@ -0,0 +1,33 @@ +SELECT + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +FROM store, item, + (SELECT + ss_store_sk, + avg(revenue) AS ave + FROM + (SELECT + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) AS revenue + FROM store_sales, 
date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_month_seq BETWEEN 1176 AND 1176 + 11 + GROUP BY ss_store_sk, ss_item_sk) sa + GROUP BY ss_store_sk) sb, + (SELECT + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) AS revenue + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_month_seq BETWEEN 1176 AND 1176 + 11 + GROUP BY ss_store_sk, ss_item_sk) sc +WHERE sb.ss_store_sk = sc.ss_store_sk AND + sc.revenue <= 0.1 * sb.ave AND + s_store_sk = sc.ss_store_sk AND + i_item_sk = sc.ss_item_sk +ORDER BY s_store_name, i_item_desc +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q66.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q66.sql new file mode 100755 index 0000000000000000000000000000000000000000..f826b4164372a79a3b8d3c191483f2a89e753551 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q66.sql @@ -0,0 +1,240 @@ +SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + ship_carriers, + year, + sum(jan_sales) AS jan_sales, + sum(feb_sales) AS feb_sales, + sum(mar_sales) AS mar_sales, + sum(apr_sales) AS apr_sales, + sum(may_sales) AS may_sales, + sum(jun_sales) AS jun_sales, + sum(jul_sales) AS jul_sales, + sum(aug_sales) AS aug_sales, + sum(sep_sales) AS sep_sales, + sum(oct_sales) AS oct_sales, + sum(nov_sales) AS nov_sales, + sum(dec_sales) AS dec_sales, + sum(jan_sales / w_warehouse_sq_ft) AS jan_sales_per_sq_foot, + sum(feb_sales / w_warehouse_sq_ft) AS feb_sales_per_sq_foot, + sum(mar_sales / w_warehouse_sq_ft) AS mar_sales_per_sq_foot, + sum(apr_sales / w_warehouse_sq_ft) AS apr_sales_per_sq_foot, + sum(may_sales / w_warehouse_sq_ft) AS may_sales_per_sq_foot, + sum(jun_sales / w_warehouse_sq_ft) AS jun_sales_per_sq_foot, + sum(jul_sales / w_warehouse_sq_ft) AS jul_sales_per_sq_foot, + sum(aug_sales / w_warehouse_sq_ft) AS aug_sales_per_sq_foot, + sum(sep_sales / w_warehouse_sq_ft) AS sep_sales_per_sq_foot, + sum(oct_sales / w_warehouse_sq_ft) AS oct_sales_per_sq_foot, + sum(nov_sales / w_warehouse_sq_ft) AS nov_sales_per_sq_foot, + sum(dec_sales / w_warehouse_sq_ft) AS dec_sales_per_sq_foot, + sum(jan_net) AS jan_net, + sum(feb_net) AS feb_net, + sum(mar_net) AS mar_net, + sum(apr_net) AS apr_net, + sum(may_net) AS may_net, + sum(jun_net) AS jun_net, + sum(jul_net) AS jul_net, + sum(aug_net) AS aug_net, + sum(sep_net) AS sep_net, + sum(oct_net) AS oct_net, + sum(nov_net) AS nov_net, + sum(dec_net) AS dec_net +FROM ( + (SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + concat('DHL', ',', 'BARIAN') AS ship_carriers, + d_year AS year, + sum(CASE WHEN d_moy = 1 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jan_sales, + sum(CASE WHEN d_moy = 2 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS feb_sales, + sum(CASE WHEN d_moy = 3 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS mar_sales, + sum(CASE WHEN d_moy = 4 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS apr_sales, + sum(CASE WHEN d_moy = 5 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS may_sales, + sum(CASE WHEN d_moy = 6 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jun_sales, + sum(CASE WHEN d_moy = 7 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jul_sales, + sum(CASE WHEN d_moy = 8 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS aug_sales, + sum(CASE WHEN d_moy = 9 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS sep_sales, + sum(CASE WHEN d_moy = 10 + THEN 
ws_ext_sales_price * ws_quantity + ELSE 0 END) AS oct_sales, + sum(CASE WHEN d_moy = 11 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS nov_sales, + sum(CASE WHEN d_moy = 12 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS dec_sales, + sum(CASE WHEN d_moy = 1 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jan_net, + sum(CASE WHEN d_moy = 2 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS feb_net, + sum(CASE WHEN d_moy = 3 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS mar_net, + sum(CASE WHEN d_moy = 4 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS apr_net, + sum(CASE WHEN d_moy = 5 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS may_net, + sum(CASE WHEN d_moy = 6 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jun_net, + sum(CASE WHEN d_moy = 7 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jul_net, + sum(CASE WHEN d_moy = 8 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS aug_net, + sum(CASE WHEN d_moy = 9 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS sep_net, + sum(CASE WHEN d_moy = 10 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS oct_net, + sum(CASE WHEN d_moy = 11 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS nov_net, + sum(CASE WHEN d_moy = 12 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS dec_net + FROM + web_sales, warehouse, date_dim, time_dim, ship_mode + WHERE + ws_warehouse_sk = w_warehouse_sk + AND ws_sold_date_sk = d_date_sk + AND ws_sold_time_sk = t_time_sk + AND ws_ship_mode_sk = sm_ship_mode_sk + AND d_year = 2001 + AND t_time BETWEEN 30838 AND 30838 + 28800 + AND sm_carrier IN ('DHL', 'BARIAN') + GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, d_year) + UNION ALL + (SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + concat('DHL', ',', 'BARIAN') AS ship_carriers, + d_year AS year, + sum(CASE WHEN d_moy = 1 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jan_sales, + sum(CASE WHEN d_moy = 2 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS feb_sales, + sum(CASE WHEN d_moy = 3 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS mar_sales, + sum(CASE WHEN d_moy = 4 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS apr_sales, + sum(CASE WHEN d_moy = 5 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS may_sales, + sum(CASE WHEN d_moy = 6 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jun_sales, + sum(CASE WHEN d_moy = 7 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jul_sales, + sum(CASE WHEN d_moy = 8 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS aug_sales, + sum(CASE WHEN d_moy = 9 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS sep_sales, + sum(CASE WHEN d_moy = 10 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS oct_sales, + sum(CASE WHEN d_moy = 11 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS nov_sales, + sum(CASE WHEN d_moy = 12 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS dec_sales, + sum(CASE WHEN d_moy = 1 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS jan_net, + sum(CASE WHEN d_moy = 2 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS feb_net, + sum(CASE WHEN d_moy = 3 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS mar_net, + sum(CASE WHEN d_moy = 4 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS apr_net, + sum(CASE WHEN d_moy = 5 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS may_net, + sum(CASE WHEN d_moy = 6 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS jun_net, + sum(CASE WHEN d_moy = 7 + THEN cs_net_paid_inc_tax 
* cs_quantity + ELSE 0 END) AS jul_net, + sum(CASE WHEN d_moy = 8 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS aug_net, + sum(CASE WHEN d_moy = 9 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS sep_net, + sum(CASE WHEN d_moy = 10 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS oct_net, + sum(CASE WHEN d_moy = 11 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS nov_net, + sum(CASE WHEN d_moy = 12 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS dec_net + FROM + catalog_sales, warehouse, date_dim, time_dim, ship_mode + WHERE + cs_warehouse_sk = w_warehouse_sk + AND cs_sold_date_sk = d_date_sk + AND cs_sold_time_sk = t_time_sk + AND cs_ship_mode_sk = sm_ship_mode_sk + AND d_year = 2001 + AND t_time BETWEEN 30838 AND 30838 + 28800 + AND sm_carrier IN ('DHL', 'BARIAN') + GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, d_year + ) + ) x +GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, + ship_carriers, year +ORDER BY w_warehouse_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q67.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q67.sql new file mode 100755 index 0000000000000000000000000000000000000000..f66e2252bdbd46f013e729081c945cccb1268aee --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q67.sql @@ -0,0 +1,38 @@ +SELECT * +FROM + (SELECT + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() + OVER (PARTITION BY i_category + ORDER BY sumsales DESC) rk + FROM + (SELECT + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + FROM store_sales, date_dim, store, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY ROLLUP (i_category, i_class, i_brand, i_product_name, d_year, d_qoy, + d_moy, s_store_id)) dw1) dw2 +WHERE rk <= 100 +ORDER BY + i_category, i_class, i_brand, i_product_name, d_year, + d_qoy, d_moy, s_store_id, sumsales, rk +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q68.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q68.sql new file mode 100755 index 0000000000000000000000000000000000000000..adb8a7189dad75ac32497de2c85bd654e2ada735 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q68.sql @@ -0,0 +1,34 @@ +SELECT + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + extended_price, + extended_tax, + list_price +FROM (SELECT + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax +FROM store_sales, date_dim, store, household_demographics, customer_address +WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND store_sales.ss_addr_sk = customer_address.ca_address_sk + AND date_dim.d_dom BETWEEN 1 AND 2 + AND (household_demographics.hd_dep_count = 4 OR + household_demographics.hd_vehicle_count = 3) + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_city IN ('Midway', 'Fairview') +GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, ca_city) dn, + 
customer, + customer_address current_addr +WHERE ss_customer_sk = c_customer_sk + AND customer.c_current_addr_sk = current_addr.ca_address_sk + AND current_addr.ca_city <> bought_city +ORDER BY c_last_name, ss_ticket_number +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q69.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q69.sql new file mode 100755 index 0000000000000000000000000000000000000000..1f0ee64f565a3081eaed1673f21737018cd78917 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q69.sql @@ -0,0 +1,38 @@ +SELECT + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3 +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + ca_state IN ('KY', 'GA', 'NM') AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2) AND + (NOT exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2) AND + NOT exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2)) +GROUP BY cd_gender, cd_marital_status, cd_education_status, + cd_purchase_estimate, cd_credit_rating +ORDER BY cd_gender, cd_marital_status, cd_education_status, + cd_purchase_estimate, cd_credit_rating +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q7.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q7.sql new file mode 100755 index 0000000000000000000000000000000000000000..6630a00548403077ff4af98ff921342d4d48ec33 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q7.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +FROM store_sales, customer_demographics, date_dim, item, promotion +WHERE ss_sold_date_sk = d_date_sk AND + ss_item_sk = i_item_sk AND + ss_cdemo_sk = cd_demo_sk AND + ss_promo_sk = p_promo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + (p_channel_email = 'N' OR p_channel_event = 'N') AND + d_year = 2000 +GROUP BY i_item_id +ORDER BY i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q70.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q70.sql new file mode 100755 index 0000000000000000000000000000000000000000..625011b212fe06761bd6917660f116d16665b163 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q70.sql @@ -0,0 +1,38 @@ +SELECT + sum(ss_net_profit) AS total_sum, + s_state, + s_county, + grouping(s_state) + grouping(s_county) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(s_state) + grouping(s_county), + CASE WHEN grouping(s_county) = 0 + THEN s_state END + ORDER BY sum(ss_net_profit) DESC) AS rank_within_parent +FROM + store_sales, date_dim d1, store +WHERE + d1.d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d1.d_date_sk = ss_sold_date_sk + AND s_store_sk = ss_store_sk + AND s_state IN + (SELECT s_state + FROM + (SELECT + s_state AS s_state, + rank() 
+ OVER (PARTITION BY s_state + ORDER BY sum(ss_net_profit) DESC) AS ranking + FROM store_sales, store, date_dim + WHERE d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d_date_sk = ss_sold_date_sk + AND s_store_sk = ss_store_sk + GROUP BY s_state) tmp1 + WHERE ranking <= 5) +GROUP BY ROLLUP (s_state, s_county) +ORDER BY + lochierarchy DESC + , CASE WHEN lochierarchy = 0 + THEN s_state END + , rank_within_parent +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q71.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q71.sql new file mode 100755 index 0000000000000000000000000000000000000000..8d724b9244e11091dd7be309a1e7c70859718338 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q71.sql @@ -0,0 +1,44 @@ +SELECT + i_brand_id brand_id, + i_brand brand, + t_hour, + t_minute, + sum(ext_price) ext_price +FROM item, + (SELECT + ws_ext_sales_price AS ext_price, + ws_sold_date_sk AS sold_date_sk, + ws_item_sk AS sold_item_sk, + ws_sold_time_sk AS time_sk + FROM web_sales, date_dim + WHERE d_date_sk = ws_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + UNION ALL + SELECT + cs_ext_sales_price AS ext_price, + cs_sold_date_sk AS sold_date_sk, + cs_item_sk AS sold_item_sk, + cs_sold_time_sk AS time_sk + FROM catalog_sales, date_dim + WHERE d_date_sk = cs_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + UNION ALL + SELECT + ss_ext_sales_price AS ext_price, + ss_sold_date_sk AS sold_date_sk, + ss_item_sk AS sold_item_sk, + ss_sold_time_sk AS time_sk + FROM store_sales, date_dim + WHERE d_date_sk = ss_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + ) AS tmp, time_dim +WHERE + sold_item_sk = i_item_sk + AND i_manager_id = 1 + AND time_sk = t_time_sk + AND (t_meal_time = 'breakfast' OR t_meal_time = 'dinner') +GROUP BY i_brand, i_brand_id, t_hour, t_minute +ORDER BY ext_price DESC, brand_id diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q72.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q72.sql new file mode 100755 index 0000000000000000000000000000000000000000..99b3eee54aa1ab1a793cae26e2d2241e4e1bb2c4 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q72.sql @@ -0,0 +1,33 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) + AND hd_buy_potential = '>10000' + AND d1.d_year = 1999 + AND hd_buy_potential = '>10000' + AND cd_marital_status = 'D' + AND d1.d_year = 1999 +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git 
a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q73.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q73.sql new file mode 100755 index 0000000000000000000000000000000000000000..881be2e9024d7219b14bfcf39046f9bdfeccb020 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q73.sql @@ -0,0 +1,30 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND date_dim.d_dom BETWEEN 1 AND 2 + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN + household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL END > 1 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN ('Williamson County', 'Franklin Parish', 'Bronx County', 'Orange County') + GROUP BY ss_ticket_number, ss_customer_sk) dj, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 1 AND 5 +ORDER BY cnt DESC diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q74.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q74.sql new file mode 100755 index 0000000000000000000000000000000000000000..154b26d6802a3a5c7199b774abe3e5b6a34d0919 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q74.sql @@ -0,0 +1,58 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + 
ELSE NULL END +ORDER BY 1, 1, 1 +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q75.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q75.sql new file mode 100755 index 0000000000000000000000000000000000000000..2a143232b5196ca694584d186decfe36efd1c4d3 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q75.sql @@ -0,0 +1,76 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY sales_cnt_diff +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q76.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q76.sql new file mode 100755 index 0000000000000000000000000000000000000000..815fa922be19d570b0b2bf50fd940302856d802c --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q76.sql @@ -0,0 +1,47 @@ +SELECT + channel, + col_name, + d_year, + d_qoy, + i_category, + COUNT(*) sales_cnt, + SUM(ext_sales_price) sales_amt +FROM ( + SELECT + 'store' AS channel, + ss_store_sk col_name, + d_year, + d_qoy, + i_category, + ss_ext_sales_price ext_sales_price + FROM store_sales, item, date_dim + WHERE ss_store_sk IS NULL + AND ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + UNION ALL + 
SELECT + 'web' AS channel, + ws_ship_customer_sk col_name, + d_year, + d_qoy, + i_category, + ws_ext_sales_price ext_sales_price + FROM web_sales, item, date_dim + WHERE ws_ship_customer_sk IS NULL + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk = i_item_sk + UNION ALL + SELECT + 'catalog' AS channel, + cs_ship_addr_sk col_name, + d_year, + d_qoy, + i_category, + cs_ext_sales_price ext_sales_price + FROM catalog_sales, item, date_dim + WHERE cs_ship_addr_sk IS NULL + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk = i_item_sk) foo +GROUP BY channel, col_name, d_year, d_qoy, i_category +ORDER BY channel, col_name, d_year, d_qoy, i_category +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q77.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q77.sql new file mode 100755 index 0000000000000000000000000000000000000000..a69df9fbcd366e8b06da7ae24ee8925ae54706f6 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q77.sql @@ -0,0 +1,100 @@ +WITH ss AS +(SELECT + s_store_sk, + sum(ss_ext_sales_price) AS sales, + sum(ss_net_profit) AS profit + FROM store_sales, date_dim, store + WHERE ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND ss_store_sk = s_store_sk + GROUP BY s_store_sk), + sr AS + (SELECT + s_store_sk, + sum(sr_return_amt) AS returns, + sum(sr_net_loss) AS profit_loss + FROM store_returns, date_dim, store + WHERE sr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND sr_store_sk = s_store_sk + GROUP BY s_store_sk), + cs AS + (SELECT + cs_call_center_sk, + sum(cs_ext_sales_price) AS sales, + sum(cs_net_profit) AS profit + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + GROUP BY cs_call_center_sk), + cr AS + (SELECT + sum(cr_return_amount) AS returns, + sum(cr_net_loss) AS profit_loss + FROM catalog_returns, date_dim + WHERE cr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), + ws AS + (SELECT + wp_web_page_sk, + sum(ws_ext_sales_price) AS sales, + sum(ws_net_profit) AS profit + FROM web_sales, date_dim, web_page + WHERE ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND ws_web_page_sk = wp_web_page_sk + GROUP BY wp_web_page_sk), + wr AS + (SELECT + wp_web_page_sk, + sum(wr_return_amt) AS returns, + sum(wr_net_loss) AS profit_loss + FROM web_returns, date_dim, web_page + WHERE wr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND wr_web_page_sk = wp_web_page_sk + GROUP BY wp_web_page_sk) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS returns, + sum(profit) AS profit +FROM + (SELECT + 'store channel' AS channel, + ss.s_store_sk AS id, + sales, + coalesce(returns, 0) AS returns, + (profit - coalesce(profit_loss, 0)) AS profit + FROM ss + LEFT JOIN sr + ON ss.s_store_sk = sr.s_store_sk + UNION ALL + SELECT + 'catalog channel' AS channel, + cs_call_center_sk AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM cs, cr + UNION ALL + SELECT + 'web channel' AS channel, + ws.wp_web_page_sk AS id, + sales, + coalesce(returns, 0) returns, + 
(profit - coalesce(profit_loss, 0)) AS profit + FROM ws + LEFT JOIN wr + ON ws.wp_web_page_sk = wr.wp_web_page_sk + ) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q78.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q78.sql new file mode 100755 index 0000000000000000000000000000000000000000..07b0940e268821d58e048476f6cb72f01cc95cd1 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q78.sql @@ -0,0 +1,64 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + ratio, + ss_qty DESC, ss_wc DESC, ss_sp DESC, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q79.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q79.sql new file mode 100755 index 0000000000000000000000000000000000000000..08f86dc2032aab0753ae1a95967c1b552d2c02ea --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q79.sql @@ -0,0 +1,27 @@ +SELECT + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (household_demographics.hd_dep_count = 6 OR + 
household_demographics.hd_vehicle_count > 2) + AND date_dim.d_dow = 1 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_number_employees BETWEEN 200 AND 295 + GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, store.s_city) ms, customer +WHERE ss_customer_sk = c_customer_sk +ORDER BY c_last_name, c_first_name, substr(s_city, 1, 30), profit +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q8.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q8.sql new file mode 100755 index 0000000000000000000000000000000000000000..497725111f4f303e7971123c348d2ca024a7853a --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q8.sql @@ -0,0 +1,87 @@ +SELECT + s_store_name, + sum(ss_net_profit) +FROM store_sales, date_dim, store, + (SELECT ca_zip + FROM ( + (SELECT substr(ca_zip, 1, 5) ca_zip + FROM customer_address + WHERE substr(ca_zip, 1, 5) IN ( + '24128','76232','65084','87816','83926','77556','20548', + '26231','43848','15126','91137','61265','98294','25782', + '17920','18426','98235','40081','84093','28577','55565', + '17183','54601','67897','22752','86284','18376','38607', + '45200','21756','29741','96765','23932','89360','29839', + '25989','28898','91068','72550','10390','18845','47770', + '82636','41367','76638','86198','81312','37126','39192', + '88424','72175','81426','53672','10445','42666','66864', + '66708','41248','48583','82276','18842','78890','49448', + '14089','38122','34425','79077','19849','43285','39861', + '66162','77610','13695','99543','83444','83041','12305', + '57665','68341','25003','57834','62878','49130','81096', + '18840','27700','23470','50412','21195','16021','76107', + '71954','68309','18119','98359','64544','10336','86379', + '27068','39736','98569','28915','24206','56529','57647', + '54917','42961','91110','63981','14922','36420','23006', + '67467','32754','30903','20260','31671','51798','72325', + '85816','68621','13955','36446','41766','68806','16725', + '15146','22744','35850','88086','51649','18270','52867', + '39972','96976','63792','11376','94898','13595','10516', + '90225','58943','39371','94945','28587','96576','57855', + '28488','26105','83933','25858','34322','44438','73171', + '30122','34102','22685','71256','78451','54364','13354', + '45375','40558','56458','28286','45266','47305','69399', + '83921','26233','11101','15371','69913','35942','15882', + '25631','24610','44165','99076','33786','70738','26653', + '14328','72305','62496','22152','10144','64147','48425', + '14663','21076','18799','30450','63089','81019','68893', + '24996','51200','51211','45692','92712','70466','79994', + '22437','25280','38935','71791','73134','56571','14060', + '19505','72425','56575','74351','68786','51650','20004', + '18383','76614','11634','18906','15765','41368','73241', + '76698','78567','97189','28545','76231','75691','22246', + '51061','90578','56691','68014','51103','94167','57047', + '14867','73520','15734','63435','25733','35474','24676', + '94627','53535','17879','15559','53268','59166','11928', + '59402','33282','45721','43933','68101','33515','36634', + '71286','19736','58058','55253','67473','41918','19515', + '36495','19430','22351','77191','91393','49156','50298', + '87501','18652','53179','18767','63193','23968','65164', + '68880','21286','72823','58470','67301','13394','31016', + '70372','67030','40604','24317','45748','39127','26065', + '77721','31029','31880','60576','24671','45549','13376', + '50016','33123','19769','22927','97789','46081','72151', + 
'15723','46136','51949','68100','96888','64528','14171', + '79777','28709','11489','25103','32213','78668','22245', + '15798','27156','37930','62971','21337','51622','67853', + '10567','38415','15455','58263','42029','60279','37125', + '56240','88190','50308','26859','64457','89091','82136', + '62377','36233','63837','58078','17043','30010','60099', + '28810','98025','29178','87343','73273','30469','64034', + '39516','86057','21309','90257','67875','40162','11356', + '73650','61810','72013','30431','22461','19512','13375', + '55307','30625','83849','68908','26689','96451','38193', + '46820','88885','84935','69035','83144','47537','56616', + '94983','48033','69952','25486','61547','27385','61860', + '58048','56910','16807','17871','35258','31387','35458', + '35576')) + INTERSECT + (SELECT ca_zip + FROM + (SELECT + substr(ca_zip, 1, 5) ca_zip, + count(*) cnt + FROM customer_address, customer + WHERE ca_address_sk = c_current_addr_sk AND + c_preferred_cust_flag = 'Y' + GROUP BY ca_zip + HAVING count(*) > 10) A1) + ) A2 + ) V1 +WHERE ss_store_sk = s_store_sk + AND ss_sold_date_sk = d_date_sk + AND d_qoy = 2 AND d_year = 1998 + AND (substr(s_zip, 1, 2) = substr(V1.ca_zip, 1, 2)) +GROUP BY s_store_name +ORDER BY s_store_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q80.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q80.sql new file mode 100755 index 0000000000000000000000000000000000000000..433db87d2a858005262d7077eb13b0ba7b4ac616 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q80.sql @@ -0,0 +1,94 @@ +WITH ssr AS +(SELECT + s_store_id AS store_id, + sum(ss_ext_sales_price) AS sales, + sum(coalesce(sr_return_amt, 0)) AS returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) AS profit + FROM store_sales + LEFT OUTER JOIN store_returns ON + (ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number) + , + date_dim, store, item, promotion + WHERE ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND ss_store_sk = s_store_sk + AND ss_item_sk = i_item_sk + AND i_current_price > 50 + AND ss_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY s_store_id), + csr AS + (SELECT + cp_catalog_page_id AS catalog_page_id, + sum(cs_ext_sales_price) AS sales, + sum(coalesce(cr_return_amount, 0)) AS returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) AS profit + FROM catalog_sales + LEFT OUTER JOIN catalog_returns ON + (cs_item_sk = cr_item_sk AND + cs_order_number = cr_order_number) + , + date_dim, catalog_page, item, promotion + WHERE cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND cs_catalog_page_sk = cp_catalog_page_sk + AND cs_item_sk = i_item_sk + AND i_current_price > 50 + AND cs_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY cp_catalog_page_id), + wsr AS + (SELECT + web_site_id, + sum(ws_ext_sales_price) AS sales, + sum(coalesce(wr_return_amt, 0)) AS returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) AS profit + FROM web_sales + LEFT OUTER JOIN web_returns ON + (ws_item_sk = wr_item_sk AND ws_order_number = wr_order_number) + , + date_dim, web_site, item, promotion + WHERE ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND ws_web_site_sk = web_site_sk + AND ws_item_sk = i_item_sk + AND i_current_price > 50 + AND 
ws_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY web_site_id) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS returns, + sum(profit) AS profit +FROM (SELECT + 'store channel' AS channel, + concat('store', store_id) AS id, + sales, + returns, + profit + FROM ssr + UNION ALL + SELECT + 'catalog channel' AS channel, + concat('catalog_page', catalog_page_id) AS id, + sales, + returns, + profit + FROM csr + UNION ALL + SELECT + 'web channel' AS channel, + concat('web_site', web_site_id) AS id, + sales, + returns, + profit + FROM wsr) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q81.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q81.sql new file mode 100755 index 0000000000000000000000000000000000000000..18f0ffa7e8f4ce7086b894059a72a906938edfa8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q81.sql @@ -0,0 +1,38 @@ +WITH customer_total_return AS +(SELECT + cr_returning_customer_sk AS ctr_customer_sk, + ca_state AS ctr_state, + sum(cr_return_amt_inc_tax) AS ctr_total_return + FROM catalog_returns, date_dim, customer_address + WHERE cr_returned_date_sk = d_date_sk + AND d_year = 2000 + AND cr_returning_addr_sk = ca_address_sk + GROUP BY cr_returning_customer_sk, ca_state ) +SELECT + c_customer_id, + c_salutation, + c_first_name, + c_last_name, + ca_street_number, + ca_street_name, + ca_street_type, + ca_suite_number, + ca_city, + ca_county, + ca_state, + ca_zip, + ca_country, + ca_gmt_offset, + ca_location_type, + ctr_total_return +FROM customer_total_return ctr1, customer_address, customer +WHERE ctr1.ctr_total_return > (SELECT avg(ctr_total_return) * 1.2 +FROM customer_total_return ctr2 +WHERE ctr1.ctr_state = ctr2.ctr_state) + AND ca_address_sk = c_current_addr_sk + AND ca_state = 'GA' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id, c_salutation, c_first_name, c_last_name, ca_street_number, ca_street_name + , ca_street_type, ca_suite_number, ca_city, ca_county, ca_state, ca_zip, ca_country, ca_gmt_offset + , ca_location_type, ctr_total_return +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q82.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q82.sql new file mode 100755 index 0000000000000000000000000000000000000000..20942cfeb0787c764cca21a87f4dcf8c5f07a9f3 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q82.sql @@ -0,0 +1,15 @@ +SELECT + i_item_id, + i_item_desc, + i_current_price +FROM item, inventory, date_dim, store_sales +WHERE i_current_price BETWEEN 62 AND 62 + 30 + AND inv_item_sk = i_item_sk + AND d_date_sk = inv_date_sk + AND d_date BETWEEN cast('2000-05-25' AS DATE) AND (cast('2000-05-25' AS DATE) + INTERVAL 60 days) + AND i_manufact_id IN (129, 270, 821, 423) + AND inv_quantity_on_hand BETWEEN 100 AND 500 + AND ss_item_sk = i_item_sk +GROUP BY i_item_id, i_item_desc, i_current_price +ORDER BY i_item_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q83.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q83.sql new file mode 100755 index 0000000000000000000000000000000000000000..53c10c7ded6c19322e822b50df7bfc677277eeca --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q83.sql @@ -0,0 +1,56 @@ +WITH sr_items AS +(SELECT + i_item_id item_id, + sum(sr_return_quantity) sr_item_qty + FROM store_returns, item, 
date_dim + WHERE sr_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND sr_returned_date_sk = d_date_sk + GROUP BY i_item_id), + cr_items AS + (SELECT + i_item_id item_id, + sum(cr_return_quantity) cr_item_qty + FROM catalog_returns, item, date_dim + WHERE cr_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND cr_returned_date_sk = d_date_sk + GROUP BY i_item_id), + wr_items AS + (SELECT + i_item_id item_id, + sum(wr_return_quantity) wr_item_qty + FROM web_returns, item, date_dim + WHERE wr_item_sk = i_item_sk AND d_date IN + (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND wr_returned_date_sk = d_date_sk + GROUP BY i_item_id) +SELECT + sr_items.item_id, + sr_item_qty, + sr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 sr_dev, + cr_item_qty, + cr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 cr_dev, + wr_item_qty, + wr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 wr_dev, + (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 average +FROM sr_items, cr_items, wr_items +WHERE sr_items.item_id = cr_items.item_id + AND sr_items.item_id = wr_items.item_id +ORDER BY sr_items.item_id, sr_item_qty +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q84.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q84.sql new file mode 100755 index 0000000000000000000000000000000000000000..a1076b57ced5c44c3ee5b96f94ec1b546e596901 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q84.sql @@ -0,0 +1,19 @@ +SELECT + c_customer_id AS customer_id, + concat(c_last_name, ', ', c_first_name) AS customername +FROM customer + , customer_address + , customer_demographics + , household_demographics + , income_band + , store_returns +WHERE ca_city = 'Edgewood' + AND c_current_addr_sk = ca_address_sk + AND ib_lower_bound >= 38128 + AND ib_upper_bound <= 38128 + 50000 + AND ib_income_band_sk = hd_income_band_sk + AND cd_demo_sk = c_current_cdemo_sk + AND hd_demo_sk = c_current_hdemo_sk + AND sr_cdemo_sk = cd_demo_sk +ORDER BY c_customer_id +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q85.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q85.sql new file mode 100755 index 0000000000000000000000000000000000000000..cf718b0f8adec9deb62c17e38b17cf5f4ad5af10 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q85.sql @@ -0,0 +1,82 @@ +SELECT + substr(r_reason_desc, 1, 20), + avg(ws_quantity), + avg(wr_refunded_cash), + avg(wr_fee) +FROM web_sales, web_returns, web_page, customer_demographics cd1, + customer_demographics cd2, customer_address, date_dim, reason +WHERE ws_web_page_sk = wp_web_page_sk + AND ws_item_sk = wr_item_sk + AND ws_order_number = wr_order_number + AND ws_sold_date_sk = d_date_sk AND d_year = 2000 + AND cd1.cd_demo_sk = wr_refunded_cdemo_sk + AND cd2.cd_demo_sk = wr_returning_cdemo_sk + AND ca_address_sk = wr_refunded_addr_sk + AND r_reason_sk = wr_reason_sk + AND + ( + ( + cd1.cd_marital_status = 'M' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = 'Advanced 
Degree' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 100.00 AND 150.00 + ) + OR + ( + cd1.cd_marital_status = 'S' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = 'College' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 50.00 AND 100.00 + ) + OR + ( + cd1.cd_marital_status = 'W' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = '2 yr Degree' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 150.00 AND 200.00 + ) + ) + AND + ( + ( + ca_country = 'United States' + AND + ca_state IN ('IN', 'OH', 'NJ') + AND ws_net_profit BETWEEN 100 AND 200 + ) + OR + ( + ca_country = 'United States' + AND + ca_state IN ('WI', 'CT', 'KY') + AND ws_net_profit BETWEEN 150 AND 300 + ) + OR + ( + ca_country = 'United States' + AND + ca_state IN ('LA', 'IA', 'AR') + AND ws_net_profit BETWEEN 50 AND 250 + ) + ) +GROUP BY r_reason_desc +ORDER BY substr(r_reason_desc, 1, 20) + , avg(ws_quantity) + , avg(wr_refunded_cash) + , avg(wr_fee) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q86.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q86.sql new file mode 100755 index 0000000000000000000000000000000000000000..789a4abf7b5f7c99bad1fbd1190f3b8dcc9102bf --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q86.sql @@ -0,0 +1,24 @@ +SELECT + sum(ws_net_paid) AS total_sum, + i_category, + i_class, + grouping(i_category) + grouping(i_class) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(i_category) + grouping(i_class), + CASE WHEN grouping(i_class) = 0 + THEN i_category END + ORDER BY sum(ws_net_paid) DESC) AS rank_within_parent +FROM + web_sales, date_dim d1, item +WHERE + d1.d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d1.d_date_sk = ws_sold_date_sk + AND i_item_sk = ws_item_sk +GROUP BY ROLLUP (i_category, i_class) +ORDER BY + lochierarchy DESC, + CASE WHEN lochierarchy = 0 + THEN i_category END, + rank_within_parent +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q87.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q87.sql new file mode 100755 index 0000000000000000000000000000000000000000..4aaa9f39dce9ee1c195878d8005665ac51788993 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q87.sql @@ -0,0 +1,28 @@ +SELECT count(*) +FROM ((SELECT DISTINCT + c_last_name, + c_first_name, + d_date +FROM store_sales, date_dim, customer +WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + EXCEPT + (SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM catalog_sales, date_dim, customer + WHERE catalog_sales.cs_sold_date_sk = date_dim.d_date_sk + AND catalog_sales.cs_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + EXCEPT + (SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM web_sales, date_dim, customer + WHERE web_sales.ws_sold_date_sk = date_dim.d_date_sk + AND web_sales.ws_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + ) cool_cust diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q88.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q88.sql new file mode 100755 index 
0000000000000000000000000000000000000000..25bcd90f41ab67e81fb9da4639e2e7d4af7ef3f0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q88.sql @@ -0,0 +1,122 @@ +SELECT * +FROM + (SELECT count(*) h8_30_to_9 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 8 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s1, + (SELECT count(*) h9_to_9_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 9 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s2, + (SELECT count(*) h9_30_to_10 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 9 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s3, + (SELECT count(*) h10_to_10_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 10 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s4, + (SELECT count(*) h10_30_to_11 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 10 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s5, + (SELECT count(*) h11_to_11_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = 
s_store_sk + AND time_dim.t_hour = 11 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s6, + (SELECT count(*) h11_30_to_12 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 11 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s7, + (SELECT count(*) h12_to_12_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 12 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s8 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q89.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q89.sql new file mode 100755 index 0000000000000000000000000000000000000000..75408cb0323f83f296f23769ece6a115b8b2c027 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q89.sql @@ -0,0 +1,30 @@ +SELECT * +FROM ( + SELECT + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, s_store_name, s_company_name) + avg_monthly_sales + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + d_year IN (1999) AND + ((i_category IN ('Books', 'Electronics', 'Sports') AND + i_class IN ('computers', 'stereo', 'football')) + OR (i_category IN ('Men', 'Jewelry', 'Women') AND + i_class IN ('shirts', 'birdal', 'dresses'))) + GROUP BY i_category, i_class, i_brand, + s_store_name, s_company_name, d_moy) tmp1 +WHERE CASE WHEN (avg_monthly_sales <> 0) + THEN (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, s_store_name +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q9.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q9.sql new file mode 100755 index 0000000000000000000000000000000000000000..de3db9d988f1ec02dc2a46329eb5164782eaf6c8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q9.sql @@ -0,0 +1,48 @@ +SELECT + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 1 AND 20) > 62316685 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 1 AND 20) + ELSE (SELECT 
avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 1 AND 20) END bucket1, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) > 19045798 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) END bucket2, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) > 365541424 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) END bucket3, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) > 216357808 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) END bucket4, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) > 184483884 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) END bucket5 +FROM reason +WHERE r_reason_sk = 1 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q90.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q90.sql new file mode 100755 index 0000000000000000000000000000000000000000..85e35bf8bf8ecd6677283092d80565fcd5ac3ec8 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q90.sql @@ -0,0 +1,19 @@ +SELECT cast(amc AS DECIMAL(15, 4)) / cast(pmc AS DECIMAL(15, 4)) am_pm_ratio +FROM (SELECT count(*) amc +FROM web_sales, household_demographics, time_dim, web_page +WHERE ws_sold_time_sk = time_dim.t_time_sk + AND ws_ship_hdemo_sk = household_demographics.hd_demo_sk + AND ws_web_page_sk = web_page.wp_web_page_sk + AND time_dim.t_hour BETWEEN 8 AND 8 + 1 + AND household_demographics.hd_dep_count = 6 + AND web_page.wp_char_count BETWEEN 5000 AND 5200) at, + (SELECT count(*) pmc + FROM web_sales, household_demographics, time_dim, web_page + WHERE ws_sold_time_sk = time_dim.t_time_sk + AND ws_ship_hdemo_sk = household_demographics.hd_demo_sk + AND ws_web_page_sk = web_page.wp_web_page_sk + AND time_dim.t_hour BETWEEN 19 AND 19 + 1 + AND household_demographics.hd_dep_count = 6 + AND web_page.wp_char_count BETWEEN 5000 AND 5200) pt +ORDER BY am_pm_ratio +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q91.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q91.sql new file mode 100755 index 0000000000000000000000000000000000000000..9ca7ce00ac775848c506121ab2884312240ee59c --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q91.sql @@ -0,0 +1,23 @@ +SELECT + cc_call_center_id Call_Center, + cc_name Call_Center_Name, + cc_manager Manager, + sum(cr_net_loss) Returns_Loss +FROM + call_center, catalog_returns, date_dim, customer, customer_address, + customer_demographics, household_demographics +WHERE + cr_call_center_sk = cc_call_center_sk + AND cr_returned_date_sk = d_date_sk + AND cr_returning_customer_sk = c_customer_sk + AND cd_demo_sk = c_current_cdemo_sk + AND hd_demo_sk = c_current_hdemo_sk + AND ca_address_sk = c_current_addr_sk + AND d_year = 1998 + AND d_moy = 11 + AND ((cd_marital_status = 'M' AND cd_education_status = 'Unknown') + OR 
(cd_marital_status = 'W' AND cd_education_status = 'Advanced Degree')) + AND hd_buy_potential LIKE 'Unknown%' + AND ca_gmt_offset = -7 +GROUP BY cc_call_center_id, cc_name, cc_manager, cd_marital_status, cd_education_status +ORDER BY sum(cr_net_loss) DESC diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q92.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q92.sql new file mode 100755 index 0000000000000000000000000000000000000000..99129c3bd9e5bf36295ce84074257c14031d4844 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q92.sql @@ -0,0 +1,16 @@ +SELECT sum(ws_ext_discount_amt) AS `Excess Discount Amount ` +FROM web_sales, item, date_dim +WHERE i_manufact_id = 350 + AND i_item_sk = ws_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + AND d_date_sk = ws_sold_date_sk + AND ws_ext_discount_amt > + ( + SELECT 1.3 * avg(ws_ext_discount_amt) + FROM web_sales, date_dim + WHERE ws_item_sk = i_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + AND d_date_sk = ws_sold_date_sk + ) +ORDER BY sum(ws_ext_discount_amt) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q93.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q93.sql new file mode 100755 index 0000000000000000000000000000000000000000..222dc31c1f561d25f26721ad8cdb63f7aed075f0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q93.sql @@ -0,0 +1,19 @@ +SELECT + ss_customer_sk, + sum(act_sales) sumsales +FROM (SELECT + ss_item_sk, + ss_ticket_number, + ss_customer_sk, + CASE WHEN sr_return_quantity IS NOT NULL + THEN (ss_quantity - sr_return_quantity) * ss_sales_price + ELSE (ss_quantity * ss_sales_price) END act_sales +FROM store_sales + LEFT OUTER JOIN store_returns + ON (sr_item_sk = ss_item_sk AND sr_ticket_number = ss_ticket_number) + , + reason +WHERE sr_reason_sk = r_reason_sk AND r_reason_desc = 'reason 28') t +GROUP BY ss_customer_sk +ORDER BY sumsales, ss_customer_sk +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q94.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q94.sql new file mode 100755 index 0000000000000000000000000000000000000000..d6de3d75b82d2b1cfcbafdc59eb12dfb54e0cbc7 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q94.sql @@ -0,0 +1,23 @@ +SELECT + count(DISTINCT ws_order_number) AS `order count `, + sum(ws_ext_ship_cost) AS `total shipping cost `, + sum(ws_net_profit) AS `total net profit ` +FROM + web_sales ws1, date_dim, customer_address, web_site +WHERE + d_date BETWEEN '1999-02-01' AND + (CAST('1999-02-01' AS DATE) + INTERVAL 60 days) + AND ws1.ws_ship_date_sk = d_date_sk + AND ws1.ws_ship_addr_sk = ca_address_sk + AND ca_state = 'IL' + AND ws1.ws_web_site_sk = web_site_sk + AND web_company_name = 'pri' + AND EXISTS(SELECT * + FROM web_sales ws2 + WHERE ws1.ws_order_number = ws2.ws_order_number + AND ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) + AND NOT EXISTS(SELECT * + FROM web_returns wr1 + WHERE ws1.ws_order_number = wr1.wr_order_number) +ORDER BY count(DISTINCT ws_order_number) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q95.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q95.sql new file mode 100755 index 0000000000000000000000000000000000000000..df71f00bd6c0b4ea9c48262b806af243420eab94 --- /dev/null +++ 
b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q95.sql @@ -0,0 +1,29 @@ +WITH ws_wh AS +(SELECT + ws1.ws_order_number, + ws1.ws_warehouse_sk wh1, + ws2.ws_warehouse_sk wh2 + FROM web_sales ws1, web_sales ws2 + WHERE ws1.ws_order_number = ws2.ws_order_number + AND ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) +SELECT + count(DISTINCT ws_order_number) AS `order count `, + sum(ws_ext_ship_cost) AS `total shipping cost `, + sum(ws_net_profit) AS `total net profit ` +FROM + web_sales ws1, date_dim, customer_address, web_site +WHERE + d_date BETWEEN '1999-02-01' AND + (CAST('1999-02-01' AS DATE) + INTERVAL 60 DAY) + AND ws1.ws_ship_date_sk = d_date_sk + AND ws1.ws_ship_addr_sk = ca_address_sk + AND ca_state = 'IL' + AND ws1.ws_web_site_sk = web_site_sk + AND web_company_name = 'pri' + AND ws1.ws_order_number IN (SELECT ws_order_number + FROM ws_wh) + AND ws1.ws_order_number IN (SELECT wr_order_number + FROM web_returns, ws_wh + WHERE wr_order_number = ws_wh.ws_order_number) +ORDER BY count(DISTINCT ws_order_number) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q96.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q96.sql new file mode 100755 index 0000000000000000000000000000000000000000..7ab17e7bc4597731560565d37465a07ecdfdd35e --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q96.sql @@ -0,0 +1,11 @@ +SELECT count(*) +FROM store_sales, household_demographics, time_dim, store +WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 20 + AND time_dim.t_minute >= 30 + AND household_demographics.hd_dep_count = 7 + AND store.s_store_name = 'ese' +ORDER BY count(*) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q97.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q97.sql new file mode 100755 index 0000000000000000000000000000000000000000..e7e0b1a05259da58ee6fedbbf971bcfc0f67384c --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q97.sql @@ -0,0 +1,30 @@ +WITH ssci AS ( + SELECT + ss_customer_sk customer_sk, + ss_item_sk item_sk + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY ss_customer_sk, ss_item_sk), + csci AS ( + SELECT + cs_bill_customer_sk customer_sk, + cs_item_sk item_sk + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY cs_bill_customer_sk, cs_item_sk) +SELECT + sum(CASE WHEN ssci.customer_sk IS NOT NULL AND csci.customer_sk IS NULL + THEN 1 + ELSE 0 END) store_only, + sum(CASE WHEN ssci.customer_sk IS NULL AND csci.customer_sk IS NOT NULL + THEN 1 + ELSE 0 END) catalog_only, + sum(CASE WHEN ssci.customer_sk IS NOT NULL AND csci.customer_sk IS NOT NULL + THEN 1 + ELSE 0 END) store_and_catalog +FROM ssci + FULL OUTER JOIN csci ON (ssci.customer_sk = csci.customer_sk + AND ssci.item_sk = csci.item_sk) +LIMIT 100 diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q98.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q98.sql new file mode 100755 index 0000000000000000000000000000000000000000..bb10d4bf8da23d6ffec3917b85646acdfdea4176 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q98.sql @@ -0,0 +1,21 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + 
sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q99.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q99.sql new file mode 100755 index 0000000000000000000000000000000000000000..f1a3d4d2b7fe9bc41fe7a5ded0490566bc92f14f --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds/q99.sql @@ -0,0 +1,34 @@ +SELECT + substr(w_warehouse_name, 1, 20), + sm_type, + cc_name, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 30) AND + (cs_ship_date_sk - cs_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 60) AND + (cs_ship_date_sk - cs_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 90) AND + (cs_ship_date_sk - cs_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + catalog_sales, warehouse, ship_mode, call_center, date_dim +WHERE + d_month_seq BETWEEN 1200 AND 1200 + 11 + AND cs_ship_date_sk = d_date_sk + AND cs_warehouse_sk = w_warehouse_sk + AND cs_ship_mode_sk = sm_ship_mode_sk + AND cs_call_center_sk = cc_call_center_sk +GROUP BY + substr(w_warehouse_name, 1, 20), sm_type, cc_name +ORDER BY substr(w_warehouse_name, 1, 20), sm_type, cc_name +LIMIT 100 diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/resources/tpcds_ddl.sql b/omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds_ddl.sql similarity index 100% rename from omnicache/omnicache-spark-extension/plugin/src/test/resources/tpcds_ddl.sql rename to omnimv/omnimv-spark-extension/plugin/src/test/resources/tpcds_ddl.sql diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala similarity index 80% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala index 632d21b6ddc173a190e4669013ff48fc4da79162..817a8d215eafd3ad7cece6fe2dd4c5ac03b536dd 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewAggregateRuleSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer.rules +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + class 
MaterializedViewAggregateRuleSuite extends RewriteSuite { test("mv_agg1") { @@ -178,7 +180,6 @@ class MaterializedViewAggregateRuleSuite extends RewriteSuite { spark.sql(sql).show() } - test("mv_agg4") { spark.sql( """ @@ -460,4 +461,148 @@ class MaterializedViewAggregateRuleSuite extends RewriteSuite { |""".stripMargin ) } + + test("mv_agg8_1") { + // Aggregation hence(The group by field is different): + // min(distinct ) / max(distinct ) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg8_1; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_agg8_1 + |AS + |SELECT + |c.deptno, + |c.locationid, + |max(c.longtype) as _max, + |min(c.floattype) as _min + |FROM column_type c + |GROUP BY c.empid,c.deptno,c.locationid; + |""".stripMargin + ) + val sql = + """ + |SELECT + |max(c.longtype) as _max, + |min(c.floattype) as _min + |FROM column_type c + |GROUP BY c.deptno,c.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_agg8_1", noData = false) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg8_1; + |""".stripMargin + ) + } + + test("mv_agg8_2") { + // Aggregation hence(The group by field is different): + // avg() + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg8_2; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_agg8_2 + |AS + |SELECT + |c.deptno, + |c.locationid, + |avg(c.longtype) as _avg, + |count(c.longtype) as _count + |FROM column_type c + |GROUP BY c.empid,c.deptno,c.locationid; + |""".stripMargin + ) + val sql = + """ + |SELECT + |avg(c.longtype) as _avg + |FROM column_type c + |GROUP BY c.deptno,c.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_agg8_2", noData = false) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg8_2; + |""".stripMargin + ) + } + + // min(distinct)/max(distinct)/avg() enhance + test("mv_agg9") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg9; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_agg9 + |AS + |SELECT c.empid,c.deptno,c.locationid, + |min(distinct c.integertype) as _min_dist, + |max(distinct c.longtype) as _max_dist, + |count(c.decimaltype) as _count, + |avg(c.decimaltype) as _avg + |FROM column_type c JOIN emps e + |ON c.empid=e.empid + |AND c.empid=1 + |GROUP BY c.empid,c.deptno,c.locationid; + |""".stripMargin + ) + } + + test("mv_agg9_1") { + val sql = + """ + |SELECT c.empid,c.deptno, + |min(distinct c.integertype) as _min_dist, + |max(distinct c.longtype) as _max_dist, + |count(c.decimaltype) as _count, + |avg(c.decimaltype) as _avg + |FROM column_type c JOIN emps e + |ON c.empid=e.empid + |AND c.empid=1 + |GROUP BY c.empid,c.deptno; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_agg9", noData = true) + } + + test("drop_mv_agg9") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg9; + |""".stripMargin + ) + } + + test("drop all mv") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg1; + |""".stripMargin + ) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg2; + |""".stripMargin + ) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg3; + |""".stripMargin + ) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_agg4; + |""".stripMargin + ) + } } diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala 
b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala similarity index 99% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala index f34e2b74e2bbdadeca072d6a0166cb030e3433b0..5b454d70ffa9fc49747cf5395c89c6cea78fbe09 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewFilterRuleSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.huawei.boostkit.spark.util.RewriteHelper +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + class MaterializedViewFilterRuleSuite extends RewriteSuite { test("mv_filter1") { diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala similarity index 61% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala index da9e4faf2961fd039181e9e449d044efcded889b..648fecd5c52e0088c6d307fead27d7911d02b1c0 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewJoinRuleSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer.rules import com.huawei.boostkit.spark.util.RewriteHelper +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + class MaterializedViewJoinRuleSuite extends RewriteSuite { test("mv_join1") { @@ -49,6 +51,114 @@ class MaterializedViewJoinRuleSuite extends RewriteSuite { comparePlansAndRows(sql, "default", "mv_join1", noData = false) } + test("mv_join1_1_subQuery") { + // is same to view + val sql = + """ + |SELECT e.*,d.deptname1 + |FROM + | (SELECT + | empid as empid, + | deptno as deptno, + | empname as empname1 + | FROM + | emps) + | e JOIN + | (SELECT + | deptno as deptno, + | deptname as deptname1 + | FROM + | depts) + | d + |ON e.deptno=d.deptno; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_join1", noData = false) + } + + test("mv_join1_1_subQuery2") { + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_join1_subQuery2 + |AS + |SELECT e.*,d.deptname + |FROM emps e JOIN depts d + |ON substring(e.deptno,0,1) =substring(d.deptno,0,1) ; + |""".stripMargin + ) + + // is same to view + val sql = + """ + |SELECT e.*,d.deptname1 + |FROM + | (SELECT + | empid as empid, + | substring(deptno,0,1) as deptno, + | empname as empname1 + | FROM + | emps) + | e JOIN + | (SELECT + | substring(deptno,0,1) as deptno, + | deptname as deptname1 + | FROM + | depts) + | d + |ON e.deptno=d.deptno; + |""".stripMargin + 
comparePlansAndRows(sql, "default", "mv_join1_subQuery2", noData = false) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_join1_subQuery2") + } + + test("mv_join1_1_subQuery3") { + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_join1_subQuery3 + |AS + |SELECT e.*,d.deptname,l.state + |FROM emps e JOIN depts d JOIN locations l + |ON substring(e.deptno,0,1) =substring(d.deptno,0,1) + |AND e.locationid=l.locationid; + |""".stripMargin + ) + + // is same to view + val sql = + """ + |SELECT k.*,l.state + |FROM + |(SELECT e.*,d.deptname1 + |FROM + | (SELECT + | empid as empid, + | substring(deptno,0,1) as deptno, + | empname as empname1, + | locationid as locationid + | FROM + | emps) + | e JOIN + | (SELECT + | substring(deptno,0,1) as deptno, + | deptname as deptname1 + | FROM + | depts) + | d + |ON e.deptno=d.deptno) + |k JOIN + |(SELECT + |locationid as locationid, + |state as state + |FROM + |locations + |) + |l + |ON k.locationid=l.locationid + |; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_join1_subQuery3", noData = false) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_join1_subQuery3") + } + test("mv_join1_2") { // is same to view, join order different val sql = @@ -167,4 +277,17 @@ class MaterializedViewJoinRuleSuite extends RewriteSuite { val sql = "ALTER MATERIALIZED VIEW mv_join2 DISABLE REWRITE;" spark.sql(sql).show() } + + test("join all mv") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_join1; + |""".stripMargin + ) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_join2; + |""".stripMargin + ) + } } diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftJoinRuleSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftJoinRuleSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..cede5a9e7736a7f2116290690f368e1b109fc6a0 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftJoinRuleSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.spark.sql.catalyst.optimizer.rules + + import com.huawei.boostkit.spark.util.RewriteHelper.{disableCachePlugin, enableCachePlugin} + + import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + + class MaterializedViewLeftJoinRuleSuite extends RewriteSuite { + + test("mv_left_join") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_left_join; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_left_join + |AS + |SELECT e.*,d.deptname + |FROM emps e LEFT JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + ) + } + + test("mv_left_join_1") { + // is same to view + val sql = + """ + |SELECT e.*,d.deptname + |FROM emps e LEFT JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = false) + } + + test("mv_left_join_2") { + // view tables is subset of query + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM emps e LEFT JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = false) + } + + test("mv_left_join_3") { + // view tables is subset of query + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM emps e LEFT JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno >= 5; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = true) + } + + test("mv_left_join_cannot_rewrite") { + val sql = + """ + |SELECT e1.*,d.deptname,e2.* + |FROM emps e1 LEFT JOIN depts d ON e1.deptno=d.deptno JOIN emps e2 + |on d.deptno = e2.deptno where e1.deptno >= 2; + |""".stripMargin + val df = spark.sql(sql) + val optPlan = df.queryExecution.optimizedPlan + disableCachePlugin() + val df2 = spark.sql(sql) + val srcPlan = df2.queryExecution.optimizedPlan + enableCachePlugin() + assert(optPlan.toString().replaceAll("#\\d+", "") + .equals(srcPlan.toString().replaceAll("#\\d+", ""))) + } + + test("mv_left_join_4") { + // view tables is subset of query, join with subquery + val sql = + """ + |SELECT v1.*,l.locationid + |FROM + |(SELECT e.*,d.deptname + |FROM emps e LEFT JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2 + |) v1 + |JOIN locations l + |ON v1.locationid=l.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = false) + } + + test("mv_left_join_5") { + // view tables is same to query, equal columns + val sql = + """ + |SELECT d.deptname + |FROM emps e LEFT JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = false) + } + + test("left_join_range1") { + // The WHERE condition range is wider than the view's, so the query cannot be rewritten + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM emps e LEFT JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno > 0; + |""".stripMargin + val df = spark.sql(sql) + val optPlan = df.queryExecution.optimizedPlan + disableCachePlugin() + val df2 = spark.sql(sql) + val srcPlan = df2.queryExecution.optimizedPlan + enableCachePlugin() + assert(optPlan.toString().replaceAll("#\\d+", "") + .equals(srcPlan.toString().replaceAll("#\\d+", ""))) + } + + test("left_join_range2") { + // The WHERE condition range is narrower than the view's, so the query can be rewritten + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM emps e LEFT JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON
e.locationid=l.locationid where e.deptno > 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_join", noData = true) + } + + test("clean_env") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_left_join; + |""".stripMargin + ) + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftSemiJoinRuleSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftSemiJoinRuleSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0ef960c4211074f7e8e08628693b9205111b4553 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewLeftSemiJoinRuleSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import com.huawei.boostkit.spark.util.RewriteHelper.{disableCachePlugin, enableCachePlugin} + +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + +class MaterializedViewLeftSemiJoinRuleSuite extends RewriteSuite { + + test("mv_left_semi_join") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_left_semi_join; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_left_semi_join + |AS + |SELECT e.* + |FROM emps e SEMI JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + ) + } + + test("mv_left_semi_join_1") { + // is same to view + val sql = + """ + |SELECT e.* + |FROM emps e SEMI JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = false) + } + + test("mv_left_semi_join_2") { + // view tables is subset of query + val sql = + """ + |SELECT e.*, l.locationid + |FROM emps e SEMI JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = false) + } + + test("mv_left_semi_join_3") { + // view tables is subset of query + val sql = + """ + |SELECT e.*, l.locationid + |FROM emps e SEMI JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno = 5; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = true) + } + + test("mv_left_semi_join_cannot_rewrite") { + val sql = + """ + |SELECT e1.*,e2.* + |FROM emps e1 SEMI JOIN depts d ON e1.deptno=d.deptno JOIN emps e2 + |on e1.deptno = e2.deptno where e1.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = true) + } + + 
test("mv_left_semi_join_4") { + // view tables is subset of query, join with subquery + val sql = + """ + |SELECT v1.*,l.locationid + |FROM + |(SELECT e.* + |FROM emps e SEMI JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2 + |) v1 + |JOIN locations l + |ON v1.locationid=l.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = false) + } + + test("mv_left_semi_join_5") { + // view tables is same to query, equal columns + val sql = + """ + |SELECT e.empname + |FROM emps e SEMI JOIN depts d + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = false) + } + + test("left_semi_join_range1") { + // where的条件范围比视图大,不能重写 + val sql = + """ + |SELECT e.*, l.locationid + |FROM emps e SEMI JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno > 0; + |""".stripMargin + val df = spark.sql(sql) + val optPlan = df.queryExecution.optimizedPlan + disableCachePlugin() + val df2 = spark.sql(sql) + val srcPlan = df2.queryExecution.optimizedPlan + enableCachePlugin() + assert(optPlan.toString().replaceAll("#\\d+", "") + .equals(srcPlan.toString().replaceAll("#\\d+", ""))) + } + + test("left_semi_join_range2") { + // where的条件范围比视图小,可以重写 + val sql = + """ + |SELECT e.*, l.locationid + |FROM emps e SEMI JOIN depts d ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno > 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_left_semi_join", noData = true) + } + + test("clean_env") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_left_semi_join; + |""".stripMargin + ) + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleAggSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleAggSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..fa2298667498a84735c0161cbacce98fcd64a10d --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleAggSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + +class MaterializedViewOuterJoinRuleAggSuite extends OuterJoinSuite { + + test("create_agg_outJoin_view_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + spark.sql( + s""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS ${joinName}_${viewNumber} + |AS + |SELECT e.empid, count(e.salary) + |FROM emps e + |${joinType} (select * from depts where deptno > 5 or deptno < 1) d ON e.deptno=d.deptno + |where e.deptno >= 2 OR e.deptno < 40 + |group by e.empid, e.locationid + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(0) + } + + test("agg_outJoin_group_diff_0_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + // is same to view but group by is different. + val joinName = joinType.replace(" ", "_") + val sql = + s""" + |SELECT e.empid, count(e.salary) + |FROM emps e + | ${joinType} (select * from depts where deptno > 5 or deptno < 1) d ON e.deptno=d.deptno + |where e.deptno >= 2 OR e.deptno < 40 + |group by e.empid + |""".stripMargin + comparePlansAndRows(sql, "default", s"${joinName}_${viewNumber}", noData = true) + } + + runOuterJoinFunc($1)(0) + } + + // It is not currently supported, but will be supported later. + test("agg_outJoin_group_diff_0_1") { + def $1(joinType: String, viewNumber: Int): Unit = { + // group by is different and query condition is subset of view condition. + val joinName = joinType.replace(" ", "_") + val sql = + s""" + |SELECT e.empid, count(e.salary) + |FROM emps e + | ${joinType} (select * from depts where deptno > 5) d ON e.deptno=d.deptno + |where e.deptno >= 2 + |group by e.empid + |""".stripMargin + compareNotRewriteAndRows(sql, noData = true) + } + + runOuterJoinFunc($1)(0) + } + + test("agg_outJoin_group_same_0_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + // is same to view but group by is different. + val joinName = joinType.replace(" ", "_") + val sql = + s""" + |SELECT e.empid, count(e.salary) + |FROM emps e + | ${joinType} (select * from depts where deptno > 5 or deptno < 1) d ON e.deptno=d.deptno + |where e.deptno >= 2 OR e.deptno < 40 + |group by e.empid, e.locationid + |""".stripMargin + comparePlansAndRows(sql, "default", s"${joinName}_${viewNumber}", noData = true) + } + + runOuterJoinFunc($1)(0) + } + + test("clean_agg_outJoin_view_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(0) + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleProjectSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleProjectSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..354b88aad971d4d8222b6d30e15ce7d0a4f87f39 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewOuterJoinRuleProjectSuite.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + +class MaterializedViewOuterJoinRuleProjectSuite extends OuterJoinSuite { + + test("create_project_outJoin_view_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + spark.sql( + s""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS ${joinName}_${viewNumber} + |AS + |SELECT d.deptno + |FROM ${leftTable} + |${joinType} ${rightTable} + |ON e.deptno=d.deptno + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(0) + } + + test("project_outJoin_filterCondition_compensate_0_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val sql = + s""" + |SELECT d.deptno + |FROM ${leftTable} + |${joinType} ${rightTable} + |ON e.deptno=d.deptno + |where d.deptno >= 40; + |""".stripMargin + comparePlansAndRows(sql, "default", s"${joinName}_${viewNumber}", noData = true) + } + + runOuterJoinFunc($1)(0) + } + + test("project_outJoin_innerCondition_compensate_0_1") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val sql = + s""" + |SELECT d.deptno + |FROM ${leftTable} + |${joinType} ${rightTable} + |ON e.deptno=d.deptno + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + comparePlansAndRows(sql, "default", s"${joinName}_${viewNumber}", noData = true) + } + + runOuterJoinFunc($1)(0) + } + + test("project_outJoin_same_0_2") { + def $1(joinType: String, viewNumber: Int): Unit = { + // is same to view. 
+ val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val sql = + s""" + |SELECT d.deptno + |FROM ${leftTable} + |${joinType} ${rightTable} + |ON e.deptno=d.deptno + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + comparePlansAndRows(sql, "default", s"${joinName}_${viewNumber}", noData = true) + } + + runOuterJoinFunc($1)(0) + } + + test("clean_project_outJoin_view_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(0) + } + + test("create_project_outJoin_view_1") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val leftAlias = leftTable.split(" ").last + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + spark.sql( + s""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS ${joinName}_${viewNumber} + |AS + |SELECT d.deptno + |FROM locations l JOIN + |${leftTable} ON l.locationid = ${leftAlias}.deptno + |${joinType} ${rightTable} + |ON e.deptno=d.deptno or ${leftAlias}.deptno is not null + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(1) + } + + /** + * The join of the view and the join of the query must match from scratch. + * Positive example: + * view: select * from a left join b join c right join d join e where ... + * query: select * from a left join b join c right join d where ... + * + * Bad example: + * view: select * from a left join b join c right join d join e where ... + * query: select * from b join c right join d where ... 
+ */ + test("project_outJoin_MatchFromHead_1_0") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val leftAlias = leftTable.split(" ").last + val sql = + s""" + |SELECT d.deptno + |FROM ${leftTable} + |${joinType} ${rightTable} + |ON e.deptno=d.deptno or ${leftAlias}.deptno is not null + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + compareNotRewriteAndRows(sql, noData = true) + } + + runOuterJoinFunc($1)(1) + } + + // At present, the out join condition needs to be consistent, + // and the support with inconsistent condition may be carried out in the future + test("project_outJoin_OutJoinCondition_diff_1_1") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + var leftTable = "(select * from depts where deptno > 50 or deptno < 5) d" + var rightTable = "emps e" + joinType match { + case "RIGHT JOIN" => + leftTable = "emps e" + rightTable = "(select * from depts where deptno > 50 or deptno < 5) d" + case "LEFT JOIN" | "SEMI JOIN" | "ANTI JOIN" => + case _ => + } + val leftAlias = leftTable.split(" ").last + val sql = + s""" + |SELECT d.deptno + |FROM locations l JOIN + |${leftTable} ON l.locationid = ${leftAlias}.deptno + |${joinType} ${rightTable} + |ON e.deptno=d.deptno + |where d.deptno >= 40 OR d.deptno < 2; + |""".stripMargin + compareNotRewriteAndRows(sql, noData = true) + } + + runOuterJoinFunc($1)(1) + } + + test("clean_project_outJoin_view_1") { + def $1(joinType: String, viewNumber: Int): Unit = { + val joinName = joinType.replace(" ", "_") + spark.sql( + s""" + |DROP MATERIALIZED VIEW IF EXISTS ${joinName}_${viewNumber}; + |""".stripMargin + ) + } + + runOuterJoinFunc($1)(1) + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewRightJoinRuleSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewRightJoinRuleSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..045d3d174ff720d176ff4fa82252730ace00645a --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/MaterializedViewRightJoinRuleSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.spark.sql.catalyst.optimizer.rules + + import com.huawei.boostkit.spark.util.RewriteHelper.{disableCachePlugin, enableCachePlugin} + + import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + + class MaterializedViewRightJoinRuleSuite extends RewriteSuite { + + test("mv_right_join") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_right_join; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_right_join + |AS + |SELECT e.*,d.deptname + |FROM depts d RIGHT JOIN emps e + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + ) + } + + test("mv_right_join_1") { + // is same to view + val sql = + """ + |SELECT e.*,d.deptname + |FROM depts d RIGHT JOIN emps e + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = false) + } + + test("mv_right_join_2") { + // view tables is subset of query + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM depts d RIGHT JOIN emps e ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = false) + } + + test("mv_right_join_3") { + // view tables is subset of query + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM depts d RIGHT JOIN emps e ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno = 5; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = true) + } + + test("mv_right_join_cannot_rewrite") { + val sql = + """ + |SELECT e1.*,d.deptname,e2.* + |FROM depts d RIGHT JOIN emps e1 ON e1.deptno=d.deptno JOIN emps e2 + |on d.deptno = e2.deptno where e1.deptno >= 2; + |""".stripMargin + val df = spark.sql(sql) + val optPlan = df.queryExecution.optimizedPlan + disableCachePlugin() + val df2 = spark.sql(sql) + val srcPlan = df2.queryExecution.optimizedPlan + enableCachePlugin() + assert(optPlan.toString().replaceAll("#\\d+", "") + .equals(srcPlan.toString().replaceAll("#\\d+", ""))) + } + + test("mv_right_join_4") { + // view tables is subset of query, join with subquery + val sql = + """ + |SELECT v1.*,l.locationid + |FROM + |(SELECT e.*,d.deptname + |FROM depts d RIGHT JOIN emps e + |ON e.deptno=d.deptno where e.deptno >= 2 + |) v1 + |JOIN locations l + |ON v1.locationid=l.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = false) + } + + test("mv_right_join_5") { + // view tables is same to query, equal columns + val sql = + """ + |SELECT d.deptname + |FROM depts d RIGHT JOIN emps e + |ON e.deptno=d.deptno where e.deptno >= 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = false) + } + + test("right_join_range1") { + // The WHERE condition range is wider than the view's, so the query cannot be rewritten + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM depts d RIGHT JOIN emps e ON e.deptno=d.deptno JOIN locations l + |ON e.locationid=l.locationid where e.deptno > 0; + |""".stripMargin + val df = spark.sql(sql) + val optPlan = df.queryExecution.optimizedPlan + disableCachePlugin() + val df2 = spark.sql(sql) + val srcPlan = df2.queryExecution.optimizedPlan + enableCachePlugin() + assert(optPlan.toString().replaceAll("#\\d+", "") + .equals(srcPlan.toString().replaceAll("#\\d+", ""))) + } + + test("right_join_range2") { + // The WHERE condition range is narrower than the view's, so the query can be rewritten + val sql = + """ + |SELECT e.*,d.deptname, l.locationid + |FROM depts d RIGHT JOIN emps e ON e.deptno=d.deptno JOIN locations l +
|ON e.locationid=l.locationid where e.deptno > 2; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_right_join", noData = true) + } + + test("clean_env") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_right_join; + |""".stripMargin + ) + } +} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/OuterJoinSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/OuterJoinSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b8d78f2484ec64306cdd076054881eb95fdb199f --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/OuterJoinSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + +class OuterJoinSuite extends RewriteSuite { + + // Since FULL OUTER JOIN cannot push the predicate down, + // it cannot compensate the predicate, + // so OUTER JOIN does not include FULL OUTER JOIN. 
+ val OUTER_JOINS = List("LEFT JOIN", "RIGHT JOIN", "SEMI JOIN", "ANTI JOIN") + + def runOuterJoinFunc(fun: (String, Int) => Unit)(viewNumber: Int): Unit = { + OUTER_JOINS.foreach { + outJoin => + fun(outJoin, viewNumber) + } + } +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala similarity index 86% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala index 196e25f5538e4471f1e92dc98beaec739db74cc2..90acefa6410603ed5cc6f0b9875276f2a9af3f26 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/RewriteSuite.scala @@ -17,42 +17,43 @@ package org.apache.spark.sql.catalyst.optimizer.rules -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig import com.huawei.boostkit.spark.util.RewriteHelper._ import java.io.File import java.util.Locale +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.funsuite.AnyFunSuite -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite.spark import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{sideBySide, toPrettySQL} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types.StringType -class RewriteSuite extends SparkFunSuite with PredicateHelper { + +class RewriteSuite extends AnyFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach + with PredicateHelper { System.setProperty("HADOOP_USER_NAME", "root") - lazy val spark: SparkSession = SparkSession.builder().master("local") - .config("spark.sql.extensions", "com.huawei.boostkit.spark.OmniCache") - .config("hive.exec.dynamic.partition.mode", "nonstrict") - .config("spark.ui.port", "4050") - // .config("spark.sql.planChangeLog.level","WARN") - .config("spark.sql.omnicache.logLevel", "WARN") - .enableHiveSupport() - .getOrCreate() - spark.sparkContext.setLogLevel("WARN") - lazy val catalog: SessionCatalog = spark.sessionState.catalog + override def beforeEach(): Unit = { enableCachePlugin() } + override def beforeAll(): Unit = { + preCreateTable() + } + def preDropTable(): Unit = { if (File.separatorChar == '\\') { return @@ -64,8 +65,10 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { } def preCreateTable(): Unit = { + disableCachePlugin() preDropTable() - if (catalog.tableExists(TableIdentifier("locations"))) { + if (RewriteSuite.catalog.tableExists(TableIdentifier("locations"))) { + enableCachePlugin() return } spark.sql( @@ 
-105,6 +108,16 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { |INSERT INTO TABLE depts VALUES(2,'deptname2'); |""".stripMargin ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(3,'deptname3'); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(4,'deptname4'); + |""".stripMargin + ) spark.sql( """ @@ -128,6 +141,12 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { |""".stripMargin ) + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(3,null,3,'empname3',3.0); + |""".stripMargin + ) + spark.sql( """ |CREATE TABLE IF NOT EXISTS column_type( @@ -238,9 +257,30 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { |); |""".stripMargin ) + enableCachePlugin() } +} + +object RewriteSuite extends AnyFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach + with PredicateHelper { - preCreateTable() + val spark: SparkSession = SparkSession.builder().master("local") + .config("spark.sql.extensions", "com.huawei.boostkit.spark.OmniMV") + .config("hive.exec.dynamic.partition.mode", "nonstrict") + .config("spark.ui.port", "4050") + // .config("spark.sql.planChangeLog.level", "WARN") + .config("spark.sql.omnimv.logLevel", "WARN") + .config("spark.sql.omnimv.dbs", "default") + .config("spark.sql.omnimv.metadata.initbyquery.enable", "false") + .config("hive.in.test", "true") + .config("spark.sql.omnimv.metadata.path", "./user/omnimv/metadata") + .config("spark.sql.omnimv.washout.automatic.enable", "false") + .enableHiveSupport() + .getOrCreate() + spark.sparkContext.setLogLevel("WARN") + lazy val catalog: SessionCatalog = spark.sessionState.catalog def transformAllExpressions(plan: LogicalPlan, rule: PartialFunction[Expression, Expression]): LogicalPlan = { @@ -450,19 +490,29 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { } } - def isRewritedByMV(database: String, mv: String, logicalPlan: LogicalPlan): Boolean = { - logicalPlan.foreachUp { + def isRewritedByMV(database: String, mvSrc: String, logicalPlan: LogicalPlan): Boolean = { + val mv = mvSrc.toLowerCase(Locale.ROOT) + logicalPlan.foreach { case _@HiveTableRelation(tableMeta, _, _, _, _) => - if (tableMeta.database == database && tableMeta.identifier.table == mv) { + if (tableMeta.database == database && tableMeta.identifier.table.contains(mv)) { return true } case _@LogicalRelation(_, _, catalogTable, _) => if (catalogTable.isDefined) { - if (catalogTable.get.database == database && catalogTable.get.identifier.table == mv) { + if (catalogTable.get.database == database && catalogTable.get.identifier + .table.contains(mv)) { return true } } - case _ => + case p => + p.transformAllExpressions { + case s: SubqueryExpression => + if (isRewritedByMV(database, mv, s.plan)) { + return true + } + s + case e => e + } } false } @@ -480,23 +530,32 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { val (rewritePlan, rewriteRows) = getPlanAndRows(sql) // 2.compare plan - assert(isRewritedByMV(database, mv, rewritePlan)) + val isRewrited = isRewritedByMV(database, mv, rewritePlan) + if (!isRewrited) { + logWarning(s"sql $sql; logicalPlan $rewritePlan is not rewritedByMV $mv") + } + assert(isRewrited) + + if (noData) { + return + } // 3.compare row disableCachePlugin() val expectedRows = getRows(sql) compareRows(rewriteRows, expectedRows, noData) + enableCachePlugin() } def isNotRewritedByMV(logicalPlan: LogicalPlan): Boolean = { logicalPlan.foreachUp { case h@HiveTableRelation(tableMeta, _, _, _, _) => - if (OmniCachePluginConfig.isMV(tableMeta)) { 
+ if (OmniMVPluginConfig.isMV(tableMeta)) { return false } case h@LogicalRelation(_, _, catalogTable, _) => if (catalogTable.isDefined) { - if (OmniCachePluginConfig.isMV(catalogTable.get)) { + if (OmniMVPluginConfig.isMV(catalogTable.get)) { return false } } @@ -516,5 +575,6 @@ class RewriteSuite extends SparkFunSuite with PredicateHelper { disableCachePlugin() val expectedRows = getRows(sql) compareRows(rewriteRows, expectedRows, noData) + enableCachePlugin() } } diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsNativeSuite.scala similarity index 33% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsNativeSuite.scala index 42adf96cce46b18647958a63928638673a00fa05..af1cc300de81308ed04baaff78d21f96345baefb 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsNativeSuite.scala @@ -17,79 +17,41 @@ package org.apache.spark.sql.catalyst.optimizer.rules -import org.apache.commons.io.IOUtils -import scala.collection.mutable +import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier - -class TpcdsSuite extends RewriteSuite { +import org.apache.spark.sql.catalyst.catalog.SessionCatalog + + +class TpcdsNativeSuite extends AnyFunSuite { + lazy val spark_native: SparkSession = SparkSession.builder().master("local") + .config("hive.exec.dynamic.partition.mode", "nonstrict") + .config("spark.ui.port", "4051") + // .config("spark.sql.planChangeLog.level", "WARN") + .config("spark.sql.omnimv.logLevel", "WARN") + .enableHiveSupport() + .getOrCreate() + lazy val catalog: SessionCatalog = spark_native.sessionState.catalog + createTable() def createTable(): Unit = { if (catalog.tableExists(TableIdentifier("store_sales"))) { return } - val fis = this.getClass.getResourceAsStream("/tpcds_ddl.sql") - val lines = IOUtils.readLines(fis, "UTF-8") - IOUtils.closeQuietly(fis) - - var sqls = Seq.empty[String] - val sql = mutable.StringBuilder.newBuilder - lines.forEach { line => - sql.append(line) - sql.append(" ") - if (line.contains(';')) { - sqls +:= sql.toString() - sql.clear() - } - } - sqls.foreach { sql => - spark.sql(sql) - } + val ddls = TpcdsUtils.getResource("/", "tpcds_ddl.sql").split(';') + ddls.foreach(ddl => spark_native.sql(ddl)) } - createTable() - - test("subQuery outReference") { - spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv536") - spark.sql( - """ - |CREATE MATERIALIZED VIEW IF NOT EXISTS mv536 PARTITIONED BY (ws_sold_date_sk) AS - | SELECT - | web_sales.ws_ext_discount_amt, - | item.i_item_sk, - | web_sales.ws_sold_date_sk, - | web_sales.ws_item_sk, - | item.i_manufact_id - |FROM - | web_sales, - | item - |WHERE - | item.i_manufact_id = 350 - | AND web_sales.ws_item_sk = item.i_item_sk - |distribute by ws_sold_date_sk; - |""".stripMargin - ) - val sql = - """ - |SELECT sum(ws_ext_discount_amt) AS `Excess Discount Amount ` - |FROM web_sales, item, date_dim - |WHERE i_manufact_id = 350 - | AND i_item_sk = ws_item_sk - | AND d_date BETWEEN '2000-01-27' 
AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) - | AND d_date_sk = ws_sold_date_sk - | AND ws_ext_discount_amt > - | ( - | SELECT 1.3 * avg(ws_ext_discount_amt) - | FROM web_sales, date_dim - | WHERE ws_item_sk = i_item_sk - | AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) - | AND d_date_sk = ws_sold_date_sk - | ) - |ORDER BY sum(ws_ext_discount_amt) - |LIMIT 100 - | - |""".stripMargin - compareNotRewriteAndRows(sql, noData = true) - spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv536") + /** + * Debug and run native tpcds sql + * sqlNum: tpcds sql's number + */ + test("Run the native tpcds sql") { + val sqlNum = 72 + val sql = TpcdsUtils.getResource("/tpcds", s"q${sqlNum}.sql") + val df = spark_native.sql(sql) + val qe = df.queryExecution + df.explain() } } diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e605cbc8d7a23d43fbe95d8ff755f97c039b33a9 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/rules/TpcdsSuite.scala @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer.rules + +import com.huawei.boostkit.spark.util.ViewMetadata +import java.util +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.Path +import scala.collection.mutable +import scala.io.Source + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ + +class TpcdsSuite extends RewriteSuite { + + def createTable(): Unit = { + if (RewriteSuite.catalog.tableExists(TableIdentifier("store_sales"))) { + return + } + val fis = this.getClass.getResourceAsStream("/tpcds_ddl.sql") + val lines = IOUtils.readLines(fis, "UTF-8") + IOUtils.closeQuietly(fis) + + var sqls = Seq.empty[String] + val sql = mutable.StringBuilder.newBuilder + lines.forEach { line => + sql.append(line) + sql.append(" ") + if (line.contains(';')) { + sqls +:= sql.toString() + sql.clear() + } + } + sqls.foreach { sql => + spark.sql(sql) + } + } + + createTable() + + test("subQuery outReference") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv536") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv536 PARTITIONED BY (ws_sold_date_sk) AS + | SELECT + | web_sales.ws_ext_discount_amt, + | item.i_item_sk, + | web_sales.ws_sold_date_sk, + | web_sales.ws_item_sk, + | item.i_manufact_id + |FROM + | web_sales, + | item + |WHERE + | item.i_manufact_id = 350 + | AND web_sales.ws_item_sk = item.i_item_sk + |distribute by ws_sold_date_sk; + |""".stripMargin + ) + val sql = + """ + |SELECT sum(ws_ext_discount_amt) AS `Excess Discount Amount ` + |FROM web_sales, item, date_dim + |WHERE i_manufact_id = 350 + | AND i_item_sk = ws_item_sk + | AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + | AND d_date_sk = ws_sold_date_sk + | AND ws_ext_discount_amt > + | ( + | SELECT 1.3 * avg(ws_ext_discount_amt) + | FROM web_sales, date_dim + | WHERE ws_item_sk = i_item_sk + | AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + | AND d_date_sk = ws_sold_date_sk + | ) + |ORDER BY sum(ws_ext_discount_amt) + |LIMIT 100 + | + |""".stripMargin + RewriteSuite.comparePlansAndRows(sql, "default", "mv536", noData = true) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv536") + } + + test("sum decimal") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_q11") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_q11 AS + | SELECT + | c_customer_id customer_id, + | c_first_name customer_first_name, + | c_last_name customer_last_name, + | c_preferred_cust_flag customer_preferred_cust_flag, + | c_birth_country customer_birth_country, + | c_login customer_login, + | c_email_address customer_email_address, + | d_year dyear, + | sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + | 's' sale_type + | FROM customer, store_sales, date_dim + | WHERE c_customer_sk = ss_customer_sk + | AND ss_sold_date_sk = d_date_sk + | GROUP BY c_customer_id + | , c_first_name + | , c_last_name + | , d_year + | , c_preferred_cust_flag + | , c_birth_country + | , c_login + | , c_email_address + | , d_year + | , c_customer_sk + |""".stripMargin + ) + val sql = + """ + |WITH year_total AS ( + | SELECT + | c_customer_id customer_id, + | c_first_name customer_first_name, + | c_last_name customer_last_name, + | c_preferred_cust_flag customer_preferred_cust_flag, + | c_birth_country customer_birth_country, + | c_login customer_login, + | c_email_address customer_email_address, + | d_year dyear, + | sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + | 's' 
sale_type + | FROM customer, store_sales, date_dim + | WHERE c_customer_sk = ss_customer_sk + | AND ss_sold_date_sk = d_date_sk + | GROUP BY c_customer_id + | , c_first_name + | , c_last_name + | , d_year + | , c_preferred_cust_flag + | , c_birth_country + | , c_login + | , c_email_address + | , d_year + | UNION ALL + | SELECT + | c_customer_id customer_id, + | c_first_name customer_first_name, + | c_last_name customer_last_name, + | c_preferred_cust_flag customer_preferred_cust_flag, + | c_birth_country customer_birth_country, + | c_login customer_login, + | c_email_address customer_email_address, + | d_year dyear, + | sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + | 'w' sale_type + | FROM customer, web_sales, date_dim + | WHERE c_customer_sk = ws_bill_customer_sk + | AND ws_sold_date_sk = d_date_sk + | GROUP BY + | c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + | c_login, c_email_address, d_year) + |SELECT t_s_secyear.customer_preferred_cust_flag + |FROM year_total t_s_firstyear + | , year_total t_s_secyear + | , year_total t_w_firstyear + | , year_total t_w_secyear + |WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + | AND t_s_firstyear.customer_id = t_w_secyear.customer_id + | AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + | AND t_s_firstyear.sale_type = 's' + | AND t_w_firstyear.sale_type = 'w' + | AND t_s_secyear.sale_type = 's' + | AND t_w_secyear.sale_type = 'w' + | AND t_s_firstyear.dyear = 2001 + | AND t_s_secyear.dyear = 2001 + 1 + | AND t_w_firstyear.dyear = 2001 + | AND t_w_secyear.dyear = 2001 + 1 + | AND t_s_firstyear.year_total > 0 + | AND t_w_firstyear.year_total > 0 + | AND CASE WHEN t_w_firstyear.year_total > 0 + | THEN t_w_secyear.year_total / t_w_firstyear.year_total + | ELSE NULL END + | > CASE WHEN t_s_firstyear.year_total > 0 + | THEN t_s_secyear.year_total / t_s_firstyear.year_total + | ELSE NULL END + |ORDER BY t_s_secyear.customer_preferred_cust_flag + |LIMIT 100 + | + |""".stripMargin + RewriteSuite.comparePlansAndRows(sql, "default", "mv_q11", noData = true) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_q11") + } + test("resort") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv103") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv103 + |PARTITIONED BY (ss_sold_date_sk) + |AS + |SELECT + | item.i_item_id, + | store_sales.ss_ext_discount_amt, + | store_sales.ss_quantity, + | item.i_item_desc, + | item.i_product_name, + | item.i_manufact_id, + | store_sales.ss_sold_date_sk, + | item.i_brand_id, + | item.i_item_sk, + | date_dim.d_moy, + | item.i_category, + | store_sales.ss_item_sk, + | item.i_brand, + | date_dim.d_date, + | date_dim.d_month_seq, + | item.i_wholesale_cost, + | date_dim.d_dom, + | store_sales.ss_net_paid, + | store_sales.ss_addr_sk, + | item.i_color, + | store_sales.ss_store_sk, + | store_sales.ss_cdemo_sk, + | store_sales.ss_list_price, + | store_sales.ss_wholesale_cost, + | store_sales.ss_ticket_number, + | date_dim.d_year, + | store_sales.ss_hdemo_sk, + | store_sales.ss_customer_sk, + | item.i_manufact, + | store_sales.ss_sales_price, + | item.i_current_price, + | item.i_class, + | store_sales.ss_ext_list_price, + | date_dim.d_quarter_name, + | item.i_units, + | item.i_manager_id, + | date_dim.d_day_name, + | store_sales.ss_coupon_amt, + | item.i_category_id, + | store_sales.ss_promo_sk, + | store_sales.ss_net_profit, + | date_dim.d_qoy, + | date_dim.d_week_seq, + | store_sales.ss_ext_sales_price, + | item.i_size, + | store_sales.ss_sold_time_sk, + | 
item.i_class_id, + | date_dim.d_dow, + | store_sales.ss_ext_wholesale_cost, + | store_sales.ss_ext_tax, + | date_dim.d_date_sk + |FROM + | date_dim, + | item, + | store_sales + |WHERE + | store_sales.ss_item_sk = item.i_item_sk + | AND date_dim.d_date_sk = store_sales.ss_sold_date_sk + | AND (item.i_manager_id = 8 OR item.i_manager_id = 1 OR item.i_manager_id = 28) + | AND (date_dim.d_year = 1998 OR date_dim.d_year = 2000 OR date_dim.d_year = 1999) + | AND date_dim.d_moy = 11 + |DISTRIBUTE BY ss_sold_date_sk; + |""".stripMargin + ) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv9") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv9 + |AS + |SELECT + | date_dim.d_year, + | item.i_category, + | item.i_item_id, + | item.i_class, + | item.i_current_price, + | item.i_item_desc, + | item.i_brand, + | date_dim.d_date, + | item.i_manufact_id, + | item.i_manager_id, + | item.i_brand_id, + | item.i_category_id, + | date_dim.d_moy, + | item.i_item_sk, + | sum(store_sales.ss_ext_sales_price) AS AGG0, + | count(1) AS AGG1 + |FROM + | date_dim, + | item, + | store_sales + |WHERE + | store_sales.ss_item_sk = item.i_item_sk + | AND store_sales.ss_sold_date_sk = date_dim.d_date_sk + |GROUP BY + | date_dim.d_year, + | item.i_category, + | item.i_item_id, + | item.i_class, + | item.i_current_price, + | item.i_item_desc, + | item.i_brand, + | date_dim.d_date, + | item.i_manufact_id, + | item.i_manager_id, + | item.i_brand_id, + | item.i_category_id, + | date_dim.d_moy, + | item.i_item_sk; + |""".stripMargin + ) + val os = ViewMetadata.fs.create(new Path(ViewMetadata.metadataPriorityPath, "mv103_9")) + val list = new util.ArrayList[String]() + list.add("default.mv9,default.mv103") + IOUtils.writeLines(list, "\n", os) + os.close() + ViewMetadata.loadViewPriorityFromFile() + val sql = + """ + |SELECT + | dt.d_year, + | item.i_category_id, + | item.i_category, + | sum(ss_ext_sales_price) + |FROM date_dim dt, store_sales, item + |WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + | AND store_sales.ss_item_sk = item.i_item_sk + | AND item.i_manager_id = 1 + | AND dt.d_moy = 11 + | AND dt.d_year = 2000 + |GROUP BY dt.d_year + | , item.i_category_id + | , item.i_category + |ORDER BY sum(ss_ext_sales_price) DESC, dt.d_year + | , item.i_category_id + | , item.i_category + |LIMIT 100 + | + |""".stripMargin + spark.sql(sql).explain() + RewriteSuite.comparePlansAndRows(sql, "default", "mv9", noData = true) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv103") + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv9") + } + + test("resort2") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv103") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv103 + |PARTITIONED BY (ss_sold_date_sk) + |AS + |SELECT + | item.i_item_id, + | store_sales.ss_ext_discount_amt, + | store_sales.ss_quantity, + | item.i_item_desc, + | item.i_product_name, + | item.i_manufact_id, + | store_sales.ss_sold_date_sk, + | item.i_brand_id, + | item.i_item_sk, + | date_dim.d_moy, + | item.i_category, + | store_sales.ss_item_sk, + | item.i_brand, + | date_dim.d_date, + | date_dim.d_month_seq, + | item.i_wholesale_cost, + | date_dim.d_dom, + | store_sales.ss_net_paid, + | store_sales.ss_addr_sk, + | item.i_color, + | store_sales.ss_store_sk, + | store_sales.ss_cdemo_sk, + | store_sales.ss_list_price, + | store_sales.ss_wholesale_cost, + | store_sales.ss_ticket_number, + | date_dim.d_year, + | store_sales.ss_hdemo_sk, + | store_sales.ss_customer_sk, + | item.i_manufact, + | store_sales.ss_sales_price, + | item.i_current_price, + | 
item.i_class, + | store_sales.ss_ext_list_price, + | date_dim.d_quarter_name, + | item.i_units, + | item.i_manager_id, + | date_dim.d_day_name, + | store_sales.ss_coupon_amt, + | item.i_category_id, + | store_sales.ss_promo_sk, + | store_sales.ss_net_profit, + | date_dim.d_qoy, + | date_dim.d_week_seq, + | store_sales.ss_ext_sales_price, + | item.i_size, + | store_sales.ss_sold_time_sk, + | item.i_class_id, + | date_dim.d_dow, + | store_sales.ss_ext_wholesale_cost, + | store_sales.ss_ext_tax, + | date_dim.d_date_sk + |FROM + | date_dim, + | item, + | store_sales + |WHERE + | store_sales.ss_item_sk = item.i_item_sk + | AND date_dim.d_date_sk = store_sales.ss_sold_date_sk + | AND (item.i_manager_id = 8 OR item.i_manager_id = 1 OR item.i_manager_id = 28) + | AND (date_dim.d_year = 1998 OR date_dim.d_year = 2000 OR date_dim.d_year = 1999) + | AND date_dim.d_moy = 11 + |DISTRIBUTE BY ss_sold_date_sk; + |""".stripMargin + ) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv9") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv9 + |AS + |SELECT + | date_dim.d_year, + | item.i_category, + | item.i_item_id, + | item.i_class, + | item.i_current_price, + | item.i_item_desc, + | item.i_brand, + | date_dim.d_date, + | item.i_manufact_id, + | item.i_manager_id, + | item.i_brand_id, + | item.i_category_id, + | date_dim.d_moy, + | item.i_item_sk, + | sum(store_sales.ss_ext_sales_price) AS AGG0, + | count(1) AS AGG1 + |FROM + | date_dim, + | item, + | store_sales + |WHERE + | store_sales.ss_item_sk = item.i_item_sk + | AND store_sales.ss_sold_date_sk = date_dim.d_date_sk + |GROUP BY + | date_dim.d_year, + | item.i_category, + | item.i_item_id, + | item.i_class, + | item.i_current_price, + | item.i_item_desc, + | item.i_brand, + | date_dim.d_date, + | item.i_manufact_id, + | item.i_manager_id, + | item.i_brand_id, + | item.i_category_id, + | date_dim.d_moy, + | item.i_item_sk; + |""".stripMargin + ) + val os = ViewMetadata.fs.create(new Path(ViewMetadata.metadataPriorityPath, "mv103_9")) + val list = new util.ArrayList[String]() + list.add("default.mv103,default.mv9") + IOUtils.writeLines(list, "\n", os) + os.close() + ViewMetadata.loadViewPriorityFromFile() + val sql = + """ + |SELECT + | dt.d_year, + | item.i_category_id, + | item.i_category, + | sum(ss_ext_sales_price) + |FROM date_dim dt, store_sales, item + |WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + | AND store_sales.ss_item_sk = item.i_item_sk + | AND item.i_manager_id = 1 + | AND dt.d_moy = 11 + | AND dt.d_year = 2000 + |GROUP BY dt.d_year + | , item.i_category_id + | , item.i_category + |ORDER BY sum(ss_ext_sales_price) DESC, dt.d_year + | , item.i_category_id + | , item.i_category + |LIMIT 100 + | + |""".stripMargin + spark.sql(sql).explain() + RewriteSuite.comparePlansAndRows(sql, "default", "mv103", noData = true) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv103") + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv9") + } + + test("subQuery condition 01") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS sc01") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS sc01 AS + |SELECT * + |FROM catalog_sales t1 + | LEFT JOIN (select * from inventory where inv_item_sk > 100 or + | inv_date_sk < 40) t2 ON (cs_item_sk = t2.inv_item_sk) + | LEFT JOIN warehouse t3 ON (t3.w_warehouse_sk = t2.inv_warehouse_sk) + | Join item t4 ON (t4.i_item_sk = t1.cs_item_sk) + |WHERE t2.inv_quantity_on_hand < t1.cs_quantity; + |""".stripMargin + ) + val sql = + """ + |SELECT + | i_item_desc, + | w_warehouse_name, + | count(CASE WHEN p_promo_sk IS 
NULL + | THEN 1 + | ELSE 0 END) promo, + | count(*) total_cnt + |FROM catalog_sales t1 + | LEFT JOIN (select * from inventory where inv_item_sk > 100) t2 + | ON (cs_item_sk = t2.inv_item_sk) + | LEFT JOIN warehouse t3 ON (t3.w_warehouse_sk = t2.inv_warehouse_sk) + | Join item t4 ON (t4.i_item_sk = t1.cs_item_sk) + | LEFT JOIN promotion ON (cs_item_sk = p_promo_sk) + |WHERE t2.inv_quantity_on_hand < t1.cs_quantity + |GROUP BY i_item_desc, w_warehouse_name; + |""".stripMargin + RewriteSuite.comparePlansAndRows(sql, "default", "sc01", noData = true) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS sc01") + } +} + +object TpcdsUtils { + /** + * Obtain the contents of the resource file + * + * @param path If the path of the file relative to reousrce is "/tpcds", enter "/tpcds". + * @param fileName If the file name is q14.sql, enter q14.sql here + * @return + */ + def getResource(path: String = "/", fileName: String): String = { + val filePath = s"${this.getClass.getResource(path).getPath}/${fileName}" + Source.fromFile(filePath).mkString + } +} + diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala similarity index 98% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala index d7b497596b0c93ecd55755a77001dd932cf34d80..26af163e4ed13dad1af540764e44e1b63236c6c4 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndOrSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala similarity index 97% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala index b6b9cef9d6e5ff8f8dba18e3da5d1c0679be4db8..139145dc91ceaa68de19443b1a73ced118cbbcb2 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyAndSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import 
org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} @@ -681,6 +682,21 @@ class SimplifyAndSuite extends RewriteSuite { assert(res.sql.equals("(spark_catalog.default.t1.`ID` = 5)")) } + test("simplify_simplifyAndEqualTo") { + val df = spark.sql( + """ + |SELECT * FROM T1 + |WHERE ID = 5 AND ID = 5 AND ID = 5; + |""".stripMargin + ) + val targetCondition = df.queryExecution.analyzed + // set unknownAsFalse = true + val simplify = ExprSimplifier(unknownAsFalse = true, pulledUpPredicates) + val res = simplify.simplify(targetCondition + .asInstanceOf[Project].child.asInstanceOf[Filter].condition) + assert(res.sql.equals("(spark_catalog.default.t1.`ID` = 5)")) + } + test("clean env") { // clean spark.sql( diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala similarity index 99% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala index 4497d31fdf6f16abce4310102b181830c2080e01..a05e1aba78ba4d17b4ea6659c9136bb34d9361f9 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyCaseSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import org.apache.spark.sql.catalyst.plans.logical.Project class SimplifyCaseSuite extends RewriteSuite { diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala similarity index 99% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala index 42a423dd26f639f50db3b3a2202d653c069a0179..e638ef5d4aff51f963f2cc627af822c6ad0d100e 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyComparisonSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression, GreaterThan, LessThan, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala 
b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala similarity index 98% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala index 440d292f5fc017d0ddfa380b3a89b1da3c9974ae..58e33fb562a972c39ea629cb606b32e90b80e007 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyNotSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala similarity index 96% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala index 00ccca4194b1745128c2b52962888786c5bb6839..61872a2736cc6a9c076817e4aacf36ab2f97d586 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/optimizer/simplify/SimplifyOrSuite.scala @@ -21,6 +21,7 @@ import com.huawei.boostkit.spark.util.ExprSimplifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/NativeSqlParseSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/NativeSqlParseSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..184aa4a1b9a1afae6c54d561cbcadefc4bb3fb4b --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/NativeSqlParseSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ +import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable} + +class NativeSqlParseSuite extends RewriteSuite { + + test("create table xxx as select xxx") { + spark.sql( + """ + |drop table if exists insert_select1; + |""".stripMargin) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_insert_select1; + |""".stripMargin) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_insert_select1 + |AS + |select locationid, state from locations where locationid = 1; + |""".stripMargin) + val df = spark.sql( + """ + |create table insert_select1 + |as select locationid, state from locations where locationid = 1; + |""".stripMargin) + val optPlan = df.queryExecution.optimizedPlan + assert(optPlan.isInstanceOf[CreateHiveTableAsSelectCommand]) + assert(isRewritedByMV("default", "mv_insert_select1", + optPlan.asInstanceOf[CreateHiveTableAsSelectCommand].query)) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_insert_select1; + |""".stripMargin) + } + + test("insert overwrite xxx select xxx") { + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_insert_select2 + |AS + |select locationid, state from locations where locationid = 2; + |""".stripMargin) + val df = spark.sql( + """ + |insert overwrite insert_select1 + |select locationid, state from locations where locationid = 2; + |""".stripMargin) + val optPlan = df.queryExecution.optimizedPlan + assert(optPlan.isInstanceOf[InsertIntoHiveTable]) + assert(isRewritedByMV("default", "mv_insert_select2", + optPlan.asInstanceOf[InsertIntoHiveTable].query)) + } + + test("insert into xxx select xxx") { + val df = spark.sql( + """ + |insert into insert_select1 + |select locationid, state from locations where locationid = 2; + |""".stripMargin) + val optPlan = df.queryExecution.optimizedPlan + assert(optPlan.isInstanceOf[InsertIntoHiveTable]) + assert(isRewritedByMV("default", "mv_insert_select2", + optPlan.asInstanceOf[InsertIntoHiveTable].query)) + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS mv_insert_select2; + |""".stripMargin) + } + + test("clean") { + spark.sql( + """ + |drop table if exists insert_select1; + |""".stripMargin) + } + +} diff --git a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala similarity index 88% rename from omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala rename to omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala index 273197c91f3a572c246f645bfc7411a6caa51c68..d7a7cd1dfc7c6e50aedbfc734740c72802ce361b 100644 --- a/omnicache/omnicache-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala +++ 
b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlParserSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.catalyst.parser -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig -import com.huawei.boostkit.spark.conf.OmniCachePluginConfig._ +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig._ import com.huawei.boostkit.spark.util.{RewriteHelper, ViewMetadata} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite._ class SqlParserSuite extends RewriteSuite { @@ -402,7 +403,7 @@ class SqlParserSuite extends RewriteSuite { |""".stripMargin.replaceAll("^[\r\n]+", "") val sql1 = sql.replaceAll("[\r\n]", "").trim - val sql2 = sql.substring(0, OmniCachePluginConfig.getConf.showMVQuerySqlLen) + val sql2 = sql.substring(0, OmniMVPluginConfig.getConf.showMVQuerySqlLen) .replaceAll("[\r\n]", "").trim assert { @@ -459,7 +460,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - !ViewMetadata.isViewExists(table.quotedString) + !ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) spark.sql( """ @@ -472,7 +473,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - ViewMetadata.isViewExists(table.quotedString) + ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) } @@ -495,7 +496,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - ViewMetadata.isViewExists(table.quotedString) + ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) spark.sql( """ @@ -508,7 +509,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - !ViewMetadata.isViewExists(table.quotedString) + !ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) } @@ -531,7 +532,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - !ViewMetadata.isViewExists(table.quotedString) + !ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) spark.sql( """ @@ -544,7 +545,7 @@ class SqlParserSuite extends RewriteSuite { .properties(MV_REWRITE_ENABLED).toBoolean ) assert( - ViewMetadata.isViewExists(table.quotedString) + ViewMetadata.isViewExists(ViewMetadata.formatViewName(table)) ) } @@ -708,4 +709,75 @@ class SqlParserSuite extends RewriteSuite { spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_create_agg1;") spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_create_agg2;") } + + test("mv_auto_update1") { + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_auto_update1;") + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_auto_update1 + |AS + |SELECT * FROM emps; + |""".stripMargin) + + val uri = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("mv_auto_update1")) + .storage.locationUri.get + val lastTime = ViewMetadata.getPathTime(uri) + + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(1,1,1,'empname1',1.0); + |""".stripMargin + ) + + val sql = "SELECT * FROM emps;" + comparePlansAndRows(sql, "default", "mv_auto_update1", noData = false) + val nowTime = ViewMetadata.getPathTime(uri) + assert(nowTime > lastTime) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_auto_update1;") + } + + test("mv_auto_update2") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF 
EXISTS mv_auto_update2; + |""".stripMargin + ) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS mv_auto_update2 + |AS + |SELECT e.*,d.deptname + |FROM emps e JOIN depts d + |ON e.deptno=d.deptno; + |""".stripMargin + ) + + val uri = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("mv_auto_update2")) + .storage.locationUri.get + val lastTime = ViewMetadata.getPathTime(uri) + + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(1,1,1,'empname1',1.0); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(1,'deptname1'); + |""".stripMargin + ) + + val sql = + """ + |SELECT e.*,d.deptname,l.locationid + |FROM emps e JOIN depts d JOIN locations l + |ON e.deptno=d.deptno AND e.locationid=l.locationid; + |""".stripMargin + comparePlansAndRows(sql, "default", "mv_auto_update2", noData = false) + val nowTime = ViewMetadata.getPathTime(uri) + assert(nowTime > lastTime) + spark.sql("DROP MATERIALIZED VIEW IF EXISTS mv_auto_update2;") + } } diff --git a/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/WashOutMVSuite.scala b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/WashOutMVSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad712170952bc00d77eb073fc324bedf0fe178f4 --- /dev/null +++ b/omnimv/omnimv-spark-extension/plugin/src/test/scala/org/apache/spark/sql/catalyst/parser/WashOutMVSuite.scala @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.parser + +import com.huawei.boostkit.spark.conf.OmniMVPluginConfig +import com.huawei.boostkit.spark.exception.OmniMVException +import com.huawei.boostkit.spark.util.RewriteHelper.{disableCachePlugin, enableCachePlugin} +import com.huawei.boostkit.spark.util.ViewMetadata +import java.io.File +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.IOUtils +import org.json4s.DefaultFormats +import org.json4s.jackson.Json +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.funsuite.AnyFunSuite +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer.rules.RewriteSuite +import org.apache.spark.sql.execution.command.WashOutStrategy + +class WashOutMVSuite extends WashOutBase { + + test("view count accumulate") { + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS view_count; + |""".stripMargin) + spark.sql( + """ + |CREATE MATERIALIZED VIEW IF NOT EXISTS view_count + | PARTITIONED BY (longtype,doubletype,datetype,stringtype) + |AS + |SELECT c1.*,e1.empname,d1.deptname FROM + |emps e1 JOIN column_type c1 JOIN depts d1 + |ON e1.empid=c1.empid + |AND c1.deptno=d1.deptno + |; + |""".stripMargin + ) + assert(ViewMetadata.viewCnt.get("default.view_count")(0) == 0) + + val sql1 = + """ + |SELECT c1.*,e1.empname,d1.deptname FROM + |emps e1 JOIN column_type c1 JOIN depts d1 + |ON e1.empid=c1.empid + |AND c1.deptno=d1.deptno + |""".stripMargin + RewriteSuite.comparePlansAndRows(sql1, "default", "view_count", noData = false) + assert(ViewMetadata.viewCnt.get("default.view_count")(0) == 1) + + val sql2 = + """ + |SELECT c1.*,e1.empname,d1.deptname,e1.salary FROM + |emps e1 JOIN column_type c1 JOIN depts d1 + |ON e1.empid=c1.empid + |AND c1.deptno=d1.deptno + |""".stripMargin + RewriteSuite.compareNotRewriteAndRows(sql2, noData = false) + assert(ViewMetadata.viewCnt.get("default.view_count")(0) == 1) + + RewriteSuite.comparePlansAndRows(sql1, "default", "view_count", noData = false) + assert(ViewMetadata.viewCnt.get("default.view_count")(0) == 2) + + spark.sql( + """ + |DROP MATERIALIZED VIEW IF EXISTS view_count; + |""".stripMargin) + } + + test("wash out mv by reserve quantity.") { + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.reserve.quantity.byViewCnt", "3") + val reserveQuantity = OmniMVPluginConfig.getConf.reserveViewQuantityByViewCount + spark.sql("WASH OUT ALL MATERIALIZED VIEW") + val random = new Random() + val viewsInfo = mutable.ArrayBuffer[(String, Array[Int])]() + for (i <- 1 to 10) { + val sql = + f""" + |SELECT * FROM COLUMN_TYPE WHERE empid=${i}0; + |""".stripMargin + // create mv + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv$i + |AS + |$sql + |""".stripMargin) + val curTimes = random.nextInt(10) + viewsInfo.append( + (ViewMetadata.getDefaultDatabase + f".wash_mv$i", Array(curTimes, i))) + // rewrite sql curTimes. 
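+      // Every successful rewrite below increments this view's counter in ViewMetadata,
+      // which is what the reserve-quantity wash-out strategy later ranks the views by.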
+ for (_ <- 1 to curTimes) { + RewriteSuite.comparePlansAndRows(sql, "default", s"wash_mv$i", noData = true) + } + } + val toDel = viewsInfo.sorted { + (x: (String, Array[Int]), y: (String, Array[Int])) => { + if (y._2(0) != x._2(0)) { + y._2(0).compare(x._2(0)) + } else { + y._2(1).compare(x._2(1)) + } + } + }.slice(reserveQuantity, viewsInfo.size).map(_._1) + spark.sql(f"WASH OUT MATERIALIZED VIEW USING " + + f"${WashOutStrategy.RESERVE_QUANTITY_BY_VIEW_COUNT} $reserveQuantity") + val data = mutable.Map[String, Array[Long]]() + loadData(new Path( + new Path(ViewMetadata.metadataPath, + ViewMetadata.getDefaultDatabase), + ViewMetadata.getViewCntPath), data) + data.foreach { + info => + assert(!toDel.contains(info._1)) + } + } + + test("wash out mv by unused days.") { + spark.sql("WASH OUT ALL MATERIALIZED VIEW") + val unUsedDays = OmniMVPluginConfig.getConf.minimumUnusedDaysForWashOut + for (i <- 1 to 5) { + val sql = + f""" + |SELECT * FROM COLUMN_TYPE WHERE empid=${i}0; + |""".stripMargin + // create mv + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv$i + |AS + |$sql + |""".stripMargin) + } + var data = mutable.Map[String, Array[Long]]() + val path = new Path(new Path( + ViewMetadata.metadataPath, ViewMetadata.getDefaultDatabase), ViewMetadata.getViewCntPath) + loadData(path, data) + var cnt = 2 + val toDel = mutable.Set[String]() + data.foreach { + a => + if (cnt > 0) { + // update mv used timestamp. + data.update(a._1, Array(1, 0)) + cnt -= 1 + toDel += a._1 + } + } + saveData(path, data) + ViewMetadata.forceLoad() + spark.sql(f"WASH OUT MATERIALIZED VIEW USING " + + f"${WashOutStrategy.UNUSED_DAYS} $unUsedDays") + data = mutable.Map[String, Array[Long]]() + loadData(path, data) + data.foreach { + info => + assert(!toDel.contains(info._1)) + } + } + + test("wash out mv by space consumed.") { + spark.sql("WASH OUT ALL MATERIALIZED VIEW") + val dropQuantity = 2 + for (i <- 1 to 10) { + val sql = + f""" + |SELECT * FROM COLUMN_TYPE WHERE empid=$i; + |""".stripMargin + // create mv + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv$i + |AS + |$sql + |""".stripMargin) + } + spark.sql("WASH OUT MATERIALIZED VIEW USING " + + f"${WashOutStrategy.DROP_QUANTITY_BY_SPACE_CONSUMED} $dropQuantity") + val data = mutable.Map[String, Array[Long]]() + val path = new Path(new Path( + ViewMetadata.metadataPath, ViewMetadata.getDefaultDatabase), ViewMetadata.getViewCntPath) + loadData(path, data) + val dropList = List(1, 4) + dropList.foreach { + a => + assert(!data.contains(f"${ViewMetadata.getDefaultDatabase}.wash_mv$a")) + } + } + + test("wash out all mv") { + spark.sql("WASH OUT ALL MATERIALIZED VIEW") + for (i <- 1 to 5) { + val sql = + f""" + |SELECT * FROM COLUMN_TYPE WHERE empid=${i}0; + |""".stripMargin + // create mv + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv$i + |AS + |$sql + |""".stripMargin) + } + var data = mutable.Map[String, Array[Long]]() + loadData(new Path( + new Path(ViewMetadata.metadataPath, + ViewMetadata.getDefaultDatabase), + ViewMetadata.getViewCntPath), data) + assert(data.size == 5) + spark.sql("WASH OUT ALL MATERIALIZED VIEW") + data = mutable.Map[String, Array[Long]]() + loadData(new Path( + new Path(ViewMetadata.metadataPath, + ViewMetadata.getDefaultDatabase), + ViewMetadata.getViewCntPath), data) + assert(data.isEmpty) + } + + test("auto wash out") { + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.unused.day", "0") + spark.sessionState.conf.setConfString( + 
"spark.sql.omnimv.washout.reserve.quantity.byViewCnt", "1") + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.drop.quantity.bySpaceConsumed", "1") + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.automatic.time.interval", "0") + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.automatic.view.quantity", "1") + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.automatic.enable", "true") + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.automatic.checkTime.interval", "0") + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv1 + |AS + |SELECT * FROM COLUMN_TYPE WHERE empid=100; + |""".stripMargin) + spark.sql( + f""" + |CREATE MATERIALIZED VIEW IF NOT EXISTS wash_mv2 + |AS + |SELECT * FROM COLUMN_TYPE WHERE empid=200; + |""".stripMargin) + val sql = + """ + |SELECT * FROM COLUMN_TYPE WHERE empid=100; + |""".stripMargin + val plan = spark.sql(sql).queryExecution.optimizedPlan + assert(RewriteSuite.isNotRewritedByMV(plan)) + spark.sessionState.conf.setConfString( + "spark.sql.omnimv.washout.automatic.enable", "false") + } +} + +class WashOutBase extends AnyFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach { + + System.setProperty("HADOOP_USER_NAME", "root") + lazy val spark: SparkSession = SparkSession.builder().master("local") + .config("spark.sql.extensions", "com.huawei.boostkit.spark.OmniMV") + .config("hive.exec.dynamic.partition.mode", "nonstrict") + .config("spark.ui.port", "4050") + // .config("spark.sql.planChangeLog.level", "WARN") + .config("spark.sql.omnimv.logLevel", "WARN") + .config("spark.sql.omnimv.dbs", "default") + .config("spark.sql.omnimv.metadata.initbyquery.enable", "false") + .config("hive.in.test", "true") + .config("spark.sql.omnimv.metadata.path", "./user/omnimv/metadata") + .config("spark.sql.omnimv.washout.automatic.enable", "false") + .enableHiveSupport() + .getOrCreate() + spark.sparkContext.setLogLevel("WARN") + lazy val catalog: SessionCatalog = spark.sessionState.catalog + + override def beforeEach(): Unit = { + enableCachePlugin() + } + + override def beforeAll(): Unit = { + preCreateTable() + } + + def preDropTable(): Unit = { + if (File.separatorChar == '\\') { + return + } + spark.sql("DROP TABLE IF EXISTS locations").show() + spark.sql("DROP TABLE IF EXISTS depts").show() + spark.sql("DROP TABLE IF EXISTS emps").show() + spark.sql("DROP TABLE IF EXISTS column_type").show() + } + + def preCreateTable(): Unit = { + disableCachePlugin() + preDropTable() + if (catalog.tableExists(TableIdentifier("locations"))) { + enableCachePlugin() + return + } + spark.sql( + """ + |CREATE TABLE IF NOT EXISTS locations( + | locationid INT, + | state STRING + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE locations VALUES(1,'state1'); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE locations VALUES(2,'state2'); + |""".stripMargin + ) + + spark.sql( + """ + |CREATE TABLE IF NOT EXISTS depts( + | deptno INT, + | deptname STRING + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(1,'deptname1'); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(2,'deptname2'); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(3,'deptname3'); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE depts VALUES(4,'deptname4'); + |""".stripMargin + ) + + spark.sql( + """ + |CREATE TABLE IF NOT EXISTS emps( + | empid INT, + | deptno INT, + | locationid INT, + | 
empname STRING, + | salary DOUBLE + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(1,1,1,'empname1',1.0); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(2,2,2,'empname2',2.0); + |""".stripMargin + ) + + spark.sql( + """ + |INSERT INTO TABLE emps VALUES(3,null,3,'empname3',3.0); + |""".stripMargin + ) + + spark.sql( + """ + |CREATE TABLE IF NOT EXISTS column_type( + | empid INT, + | deptno INT, + | locationid INT, + | booleantype BOOLEAN, + | bytetype BYTE, + | shorttype SHORT, + | integertype INT, + | longtype LONG, + | floattype FLOAT, + | doubletype DOUBLE, + | datetype DATE, + | timestamptype TIMESTAMP, + | stringtype STRING, + | decimaltype DECIMAL + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 1,1,1,TRUE,1,1,1,1,1.0,1.0, + | DATE '2022-01-01', + | TIMESTAMP '2022-01-01', + | 'stringtype1',1.0 + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 2,2,2,TRUE,2,2,2,2,2.0,2.0, + | DATE '2022-02-02', + | TIMESTAMP '2022-02-02', + | 'stringtype2',2.0 + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 1,1,1,null,null,null,null,null,null,null, + | null, + | null, + | null,null + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 3,3,3,TRUE,3,3,3,3,3.0,3.0, + | DATE '2022-03-03', + | TIMESTAMP '2022-03-03', + | 'stringtype3',null + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 4,4,4,TRUE,4,4,4,4,4.0,4.0, + | DATE '2022-04-04', + | TIMESTAMP '2022-04-04', + | null,4.0 + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 4,4,4,TRUE,4,4,4,4,4.0,4.0, + | DATE '2022-04-04', + | null, + | null,4.0 + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 4,4,4,TRUE,4,4,4,4,4.0,4.0, + | DATE '2022-04-04', + | TIMESTAMP '2022-04-04', + | 'stringtype4',null + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 1,1,2,TRUE,1,1,1,1,1.0,1.0, + | DATE '2022-01-01', + | TIMESTAMP '2022-01-01', + | 'stringtype1',1.0 + |); + |""".stripMargin + ) + spark.sql( + """ + |INSERT INTO TABLE column_type VALUES( + | 1,1,2,TRUE,1,1,1,1,1.0,1.0, + | DATE '2022-01-02', + | TIMESTAMP '2022-01-01', + | 'stringtype1',1.0 + |); + |""".stripMargin + ) + enableCachePlugin() + } + + def loadData[K: Manifest, V: Manifest](file: Path, + buffer: mutable.Map[K, V]): Unit = { + try { + val fs = file.getFileSystem(new Configuration) + val is = fs.open(file) + val content = IOUtils.readFullyToByteArray(is) + .map(_.toChar.toString).reduce((a, b) => a + b) + Json(DefaultFormats).read[mutable.Map[K, V]](content).foreach { + data => + buffer += data + } + is.close() + } catch { + case _: Throwable => + throw OmniMVException("load data failed.") + } + } + + def saveData[K: Manifest, V: Manifest](file: Path, + buffer: mutable.Map[K, V]): Unit = { + try { + val fs = file.getFileSystem(new Configuration) + val os = fs.create(file, true) + val bytes = Json(DefaultFormats).write(buffer).getBytes + os.write(bytes) + os.close() + } catch { + case _: Throwable => + throw OmniMVException("save data failed.") + } + } +} diff --git a/omnicache/omnicache-spark-extension/pom.xml b/omnimv/omnimv-spark-extension/pom.xml similarity index 86% rename from omnicache/omnicache-spark-extension/pom.xml rename to omnimv/omnimv-spark-extension/pom.xml index 
28712374061c30dc3b86ffa62244d1c851fd807e..bc9b21947dc4e2fc5ad95b8569ca18c89a0ad728 100644 --- a/omnicache/omnicache-spark-extension/pom.xml +++ b/omnimv/omnimv-spark-extension/pom.xml @@ -6,9 +6,9 @@ 4.0.0 com.huawei.kunpeng - boostkit-omnicache-spark-parent + boostkit-omnimv-spark-parent pom - 3.1.1-1.0.0 + ${omnimv.version} plugin @@ -18,6 +18,7 @@ BoostKit Spark MaterializedView Sql Engine Extension Parent Pom + 3.1.1-1.1.0 2.12.10 2.12 1.8 @@ -35,11 +36,18 @@ 3.1.2 1.4.11 8.29 + 4.0.2 + true + + com.esotericsoftware + kryo-shaded + ${kryo-shaded.version} + org.apache.spark spark-sql_${scala.binary.version} @@ -57,23 +65,6 @@ - - org.apache.spark - spark-core_${scala.binary.version} - ${spark.version} - test-jar - test - - - org.apache.hadoop - hadoop-client - - - org.apache.curator - curator-recipes - - - junit junit diff --git a/omnicache/omnicache-spark-extension/scalastyle-config.xml b/omnimv/omnimv-spark-extension/scalastyle-config.xml similarity index 100% rename from omnicache/omnicache-spark-extension/scalastyle-config.xml rename to omnimv/omnimv-spark-extension/scalastyle-config.xml diff --git a/omnioperator/omniop-openlookeng-extension/pom.xml b/omnioperator/omniop-openlookeng-extension/pom.xml index d59ef4ecace60bf70383397ddf80ddd41b8913fd..513d6cad0fbc43e99800e0ecdc497cea48b5eddb 100644 --- a/omnioperator/omniop-openlookeng-extension/pom.xml +++ b/omnioperator/omniop-openlookeng-extension/pom.xml @@ -21,7 +21,7 @@ 3.2.0-8 3.1.2-1 2.11.4 - 1.1.0 + 1.3.0 @@ -290,5 +290,4 @@ - - + \ No newline at end of file diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java index fd1240a20721eecc420c9eaf3e9b64d34c0382b1..b7a4718da9858567222f2e5185eacc5242d8f76d 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java @@ -50,6 +50,7 @@ import io.prestosql.operator.OutputFactory; import io.prestosql.operator.PagesIndex; import io.prestosql.operator.PartitionFunction; import io.prestosql.operator.PartitionedLookupSourceFactory; +import io.prestosql.operator.PartitionedOutputOperator; import io.prestosql.operator.PipelineExecutionStrategy; import io.prestosql.operator.ScanFilterAndProjectOperator; import io.prestosql.operator.SourceOperatorFactory; @@ -125,6 +126,7 @@ import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.NodeRef; import io.prestosql.statestore.StateStoreProvider; import io.prestosql.statestore.listener.StateStoreListenerManager; +import nova.hetu.olk.memory.OpenLooKengMemoryManager; import nova.hetu.olk.operator.AbstractOmniOperatorFactory; import nova.hetu.olk.operator.AggregationOmniOperator; import nova.hetu.olk.operator.BuildOffHeapOmniOperator; @@ -138,8 +140,6 @@ import nova.hetu.olk.operator.LimitOmniOperator; import nova.hetu.olk.operator.LocalMergeSourceOmniOperator; import nova.hetu.olk.operator.LookupJoinOmniOperators; import nova.hetu.olk.operator.MergeOmniOperator; -import nova.hetu.olk.operator.PartitionedOutputOmniOperator; -import nova.hetu.olk.operator.ScanFilterAndProjectOmniOperator; import nova.hetu.olk.operator.TopNOmniOperator; import nova.hetu.olk.operator.WindowOmniOperator; import nova.hetu.olk.operator.filterandproject.FilterAndProjectOmniOperator; @@ -149,7 +149,6 @@ import 
nova.hetu.olk.operator.localexchange.LocalExchangeSinkOmniOperator; import nova.hetu.olk.operator.localexchange.LocalExchangeSourceOmniOperator; import nova.hetu.olk.operator.localexchange.OmniLocalExchange; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.type.DataType; @@ -384,7 +383,7 @@ public class OmniLocalExecutionPlanner List partitionedSourceOrder, OutputBuffer outputBuffer, Optional feederCTEId, Optional feederCTEParentId, Map cteCtx) { - VecAllocatorHelper.createTaskLevelAllocator(taskContext); + OpenLooKengMemoryManager.setGlobalMemoryLimit(); List outputLayout = partitioningScheme.getOutputLayout(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION) @@ -440,13 +439,25 @@ public class OmniLocalExecutionPlanner nullChannel = OptionalInt.of(outputLayout.indexOf(getOnlyElement(partitioningColumns))); } boolean isHashPrecomputed = partitioningScheme.getHashColumn().isPresent(); - return plan(taskContext, stageExecutionDescriptor, plan, outputLayout, types, partitionedSourceOrder, + return plan( + taskContext, + stageExecutionDescriptor, + plan, + outputLayout, + types, + partitionedSourceOrder, outputBuffer, - new PartitionedOutputOmniOperator.PartitionedOutputOmniFactory(partitionFunction, partitionChannels, - partitionConstants, partitioningScheme.isReplicateNullsAndAny(), nullChannel, outputBuffer, - maxPagePartitioningBufferSize, partitioningScheme.getBucketToPartition().get(), - isHashPrecomputed, partitionChannelTypes), - feederCTEId, feederCTEParentId, cteCtx); + new PartitionedOutputOperator.PartitionedOutputFactory( + partitionFunction, + partitionChannels, + partitionConstants, + partitioningScheme.isReplicateNullsAndAny(), + nullChannel, + outputBuffer, + maxPagePartitioningBufferSize), + feederCTEId, + feederCTEParentId, + cteCtx); } @Override @@ -455,7 +466,7 @@ public class OmniLocalExecutionPlanner OutputBuffer outputBuffer, OutputFactory outputOperatorFactory, Optional feederCTEId, Optional feederCTEParentId, Map cteCtx) { - VecAllocatorHelper.createTaskLevelAllocator(taskContext); + OpenLooKengMemoryManager.setGlobalMemoryLimit(); Session session = taskContext.getSession(); LocalExecutionPlanContext context = new OmniLocalExecutionPlanContext(taskContext, types, metadata, dynamicFilterCacheManager, feederCTEId, feederCTEParentId, cteCtx); @@ -803,27 +814,18 @@ public class OmniLocalExecutionPlanner Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(translatedFilter, translatedProjections, sourceNode.getId()); SourceOperatorFactory operatorFactory; - if (useOmniOperator) { - operatorFactory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( - context.getSession(), context.getNextOperatorId(), planNodeId, sourceNode, - pageSourceProvider, cursorProcessor, pageProcessor, table, columns, dynamicFilter, - projections.stream().map(expression -> expression.getType()).collect(toImmutableList()), - stateStoreProvider, metadata, dynamicFilterCacheManager, - getFilterAndProjectMinOutputPageSize(session), - getFilterAndProjectMinOutputPageRowCount(session), strategy, reuseTableScanMappingId, - spillEnabled, Optional.of(spillerFactory), spillerThreshold, consumerTableScanNodeCount, - inputTypes, (OmniLocalExecutionPlanContext) context); - } - else { - operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( - context.getSession(), 
context.getNextOperatorId(), planNodeId, sourceNode, - pageSourceProvider, cursorProcessor, pageProcessor, table, columns, dynamicFilter, - projections.stream().map(expression -> expression.getType()).collect(toImmutableList()), - stateStoreProvider, metadata, dynamicFilterCacheManager, - getFilterAndProjectMinOutputPageSize(session), - getFilterAndProjectMinOutputPageRowCount(session), strategy, reuseTableScanMappingId, - spillEnabled, Optional.of(spillerFactory), spillerThreshold, consumerTableScanNodeCount); - } + + pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, + Optional.of(context.getStageId() + "_" + planNodeId)); + operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( + context.getSession(), context.getNextOperatorId(), planNodeId, sourceNode, + pageSourceProvider, cursorProcessor, pageProcessor, table, columns, dynamicFilter, + projections.stream().map(expression -> expression.getType()).collect(toImmutableList()), + stateStoreProvider, metadata, dynamicFilterCacheManager, + getFilterAndProjectMinOutputPageSize(session), + getFilterAndProjectMinOutputPageRowCount(session), strategy, reuseTableScanMappingId, + spillEnabled, Optional.of(spillerFactory), spillerThreshold, consumerTableScanNodeCount); + return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ByteArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ByteArrayOmniBlock.java index 557ccca7424d6e1b10b66ba5f914f9d778673ce1..c9bed4649e12d51fc0dbf5a2d2f5566c747068d6 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ByteArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ByteArrayOmniBlock.java @@ -21,21 +21,15 @@ import io.prestosql.spi.block.ByteArrayBlockEncoding; import io.prestosql.spi.util.BloomFilter; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Byte array omni block. @@ -47,34 +41,15 @@ public class ByteArrayOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteArrayOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - - private final int arrayOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final BooleanVec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Byte array omni block. 
- * - * @param vecAllocator the vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public ByteArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, - byte[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Byte array omni block. @@ -84,34 +59,24 @@ public class ByteArrayOmniBlock */ public ByteArrayOmniBlock(int positionCount, BooleanVec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); - } - - /** - * Instantiates a new Byte array omni block. - * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public ByteArrayOmniBlock(int positionCount, Optional valueIsNull, BooleanVec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (Byte.BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Byte array omni block. * - * @param vecAllocator the vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public ByteArrayOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, byte[] valueIsNull, + public ByteArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, byte[] values) { - this.vecAllocator = vecAllocator; if (arrayOffset < 0) { throw new IllegalArgumentException("arrayOffset is negative"); } @@ -125,7 +90,7 @@ public class ByteArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new BooleanVec(vecAllocator, positionCount); + this.values = new BooleanVec(positionCount); this.values.put(values, 0, arrayOffset, positionCount); if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { @@ -134,49 +99,9 @@ public class ByteArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - this.arrayOffset = 0; - - sizeInBytes = (Byte.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Byte array omni block. 
- * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - ByteArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, BooleanVec values) - { - this.vecAllocator = values.getAllocator(); - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - this.arrayOffset = arrayOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; - sizeInBytes = (Byte.BYTES + Byte.BYTES) * (long) positionCount; retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); } @@ -227,9 +152,6 @@ public class ByteArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values.get(0, positionCount), (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -270,14 +192,14 @@ public class ByteArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -292,7 +214,7 @@ public class ByteArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new ByteArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, + return new ByteArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, new byte[]{(values.get(position) ? (byte) 1 : (byte) 0)}); } @@ -300,20 +222,16 @@ public class ByteArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; BooleanVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new ByteArrayOmniBlock(0, length, newValueIsNull, newValues); + return new ByteArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - BooleanVec newValues = values.slice(positionOffset, positionOffset + length); - return new ByteArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + BooleanVec newValues = values.slice(positionOffset, length); + return new ByteArrayOmniBlock(length, newValues); } @Override @@ -321,15 +239,9 @@ public class ByteArrayOmniBlock { checkValidRegion(getPositionCount(), positionOffset, length); - BooleanVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? 
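Across ByteArrayOmniBlock (and, as the rest of this patch shows, the other *OmniBlock classes) the on-heap valueIsNull byte array and arrayOffset field are removed; null tracking is delegated to the native vector and cached in a hasNull flag. A short sketch of the new contract using only calls that appear in this patch, where BooleanVec is nova.hetu.omniruntime.vector.BooleanVec and rawBytes, valueIsNull, arrayOffset and positionCount are assumed inputs:

    BooleanVec values = new BooleanVec(positionCount);            // no VecAllocator argument any more
    values.put(rawBytes, 0, arrayOffset, positionCount);          // copy the byte values into the vector
    if (valueIsNull != null) {
        values.setNulls(0, valueIsNull, arrayOffset, positionCount);
    }
    boolean mayHaveNull = values.hasNull();                       // replaces "valueIsNull != null"
    boolean thirdIsNull = values.isNull(3);                       // replaces valueIsNull[3 + arrayOffset] == Vec.NULL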
null - : compactArray(valueIsNull, positionOffset + arrayOffset, length); - - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - return new ByteArrayOmniBlock(0, length, newValueIsNull, newValues); + BooleanVec newValues = values.slice(positionOffset, length); + values.close(); + return new ByteArrayOmniBlock(length, newValues); } @Override @@ -368,7 +280,7 @@ public class ByteArrayOmniBlock { int matchCount = 0; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + arrayOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } @@ -383,7 +295,7 @@ public class ByteArrayOmniBlock @Override public Byte get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } return values.get(position) ? (byte) 1 : (byte) 0; diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DictionaryOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DictionaryOmniBlock.java index 7ec5df5e4093513eedcd41a2d8e772851062bb40..55eb9ee5a69f878f7561503b772db8c7c71527ac 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DictionaryOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DictionaryOmniBlock.java @@ -21,8 +21,15 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.StandardErrorCode; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.ByteArrayBlock; +import io.prestosql.spi.block.DictionaryBlock; import io.prestosql.spi.block.DictionaryBlockEncoding; import io.prestosql.spi.block.DictionaryId; +import io.prestosql.spi.block.Int128ArrayBlock; +import io.prestosql.spi.block.IntArrayBlock; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.block.ShortArrayBlock; +import io.prestosql.spi.block.VariableWidthBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.BooleanVec; @@ -39,6 +46,7 @@ import nova.hetu.omniruntime.vector.VecEncoding; import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; @@ -49,6 +57,7 @@ import static io.prestosql.spi.block.BlockUtil.checkValidPositions; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; import static io.prestosql.spi.block.DictionaryId.randomDictionaryId; +import static java.lang.Double.doubleToLongBits; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.buildRowOmniBlock; @@ -163,7 +172,7 @@ public class DictionaryOmniBlock this.positionCount = positionCount; this.dictionaryVec = new DictionaryVec(dictionary, ids); - this.dictionary = buildBlock(dictionaryVec.getDictionary()); + this.dictionary = createFlatBlock(dictionary.getType().getId(), dictionary); this.ids = ids; this.dictionarySourceId = requireNonNull(dictionarySourceId, "dictionarySourceId is null"); this.retainedSizeInBytes = INSTANCE_SIZE + this.dictionary.getRetainedSizeInBytes() + sizeOf(ids); @@ -184,16 +193,16 @@ public class DictionaryOmniBlock DictionaryId dictionarySourceId) { 
this.positionCount = dictionaryVec.getSize(); - this.idsOffset = dictionaryVec.getOffset(); - this.dictionary = buildBlock(dictionaryVec.getDictionary()); - this.ids = dictionaryVec.getIds(); + this.idsOffset = 0; + this.dictionary = expandDictionary(dictionaryVec); + this.ids = getIds(positionCount); this.dictionarySourceId = requireNonNull(dictionarySourceId, "dictionarySourceId is null"); this.retainedSizeInBytes = INSTANCE_SIZE + dictionary.getRetainedSizeInBytes() + sizeOf(ids); this.dictionaryVec = dictionaryVec; if (dictionaryIsCompacted) { this.sizeInBytes = this.retainedSizeInBytes; - this.uniqueIds = dictionary.getPositionCount(); + this.uniqueIds = dictionaryVec.getSize(); } } @@ -204,7 +213,7 @@ public class DictionaryOmniBlock VecEncoding vecEncoding = dictionary.getEncoding(); switch (vecEncoding) { case OMNI_VEC_ENCODING_FLAT: - dictionaryBlock = createFlatBlock(dataType.getId(), dictionary); + dictionaryBlock = createFlatOmniBlock(dataType.getId(), dictionary); break; case OMNI_VEC_ENCODING_DICTIONARY: dictionaryBlock = new DictionaryOmniBlock((DictionaryVec) dictionary, false, randomDictionaryId()); @@ -218,7 +227,7 @@ public class DictionaryOmniBlock return dictionaryBlock; } - private static Block createFlatBlock(DataType.DataTypeId dataTypeId, Vec dictionary) + private static Block createFlatOmniBlock(DataType.DataTypeId dataTypeId, Vec dictionary) { Block dictionaryBlock; switch (dataTypeId) { @@ -252,6 +261,123 @@ public class DictionaryOmniBlock return dictionaryBlock; } + private static Block createVariableWidthBlock(Vec vec, int positionCount) + { + VarcharVec varcharVec = (VarcharVec) vec; + Slice slice = Slices.wrappedBuffer(varcharVec.get(0, positionCount)); + int[] offsets = new int[positionCount + 1]; + for (int i = 0; i < positionCount; i++) { + offsets[i + 1] = offsets[i] + varcharVec.getDataLength(i); + } + + return new VariableWidthBlock(positionCount, slice, offsets, + varcharVec.hasNull() + ? 
Optional.of(varcharVec.getValuesNulls(0, positionCount)) + : Optional.empty()); + } + + private static Block createInt128ArrayBlock(Vec vec, int positionCount) + { + Decimal128Vec decimal128Vec = (Decimal128Vec) vec; + return new Int128ArrayBlock(positionCount, Optional.of(decimal128Vec.getValuesNulls(0, positionCount)), + decimal128Vec.get(0, positionCount)); + } + + private static Block createDoubleArrayBlock(Vec vec, int positionCount) + { + DoubleVec doubleVec = (DoubleVec) vec; + boolean[] valuesNulls = doubleVec.getValuesNulls(0, positionCount); + long[] values = new long[positionCount]; + for (int j = 0; j < positionCount; j++) { + if (!vec.isNull(j)) { + values[j] = doubleToLongBits(doubleVec.get(j)); + } + } + return new LongArrayBlock(positionCount, Optional.of(valuesNulls), values); + } + + private static Block createShortArrayBlock(Vec vec, int positionCount) + { + ShortVec shortVec = (ShortVec) vec; + return new ShortArrayBlock(positionCount, Optional.of(shortVec.getValuesNulls(0, positionCount)), + shortVec.get(0, positionCount)); + } + + private static Block createLongArrayBlock(Vec vec, int positionCount) + { + LongVec longVec = (LongVec) vec; + return new LongArrayBlock(positionCount, Optional.of(longVec.getValuesNulls(0, positionCount)), + longVec.get(0, positionCount)); + } + + private static Block createIntArrayBlock(Vec vec, int positionCount) + { + IntVec intVec = (IntVec) vec; + return new IntArrayBlock(positionCount, Optional.of(intVec.getValuesNulls(0, positionCount)), + intVec.get(0, positionCount)); + } + + private static Block createByteArrayBlock(Vec vec, int positionCount) + { + BooleanVec booleanVec = (BooleanVec) vec; + byte[] bytes = booleanVec.getValuesBuf().getBytes(0, positionCount); + return new ByteArrayBlock(positionCount, Optional.of(booleanVec.getValuesNulls(0, positionCount)), + bytes); + } + + private static Block createFlatBlock(DataType.DataTypeId dataTypeId, Vec dictionary) + { + switch (dataTypeId) { + case OMNI_BOOLEAN: + return createByteArrayBlock(dictionary, dictionary.getSize()); + case OMNI_INT: + case OMNI_DATE32: + return createIntArrayBlock(dictionary, dictionary.getSize()); + case OMNI_SHORT: + return createShortArrayBlock(dictionary, dictionary.getSize()); + case OMNI_LONG: + case OMNI_DECIMAL64: + return createLongArrayBlock(dictionary, dictionary.getSize()); + case OMNI_DOUBLE: + return createDoubleArrayBlock(dictionary, dictionary.getSize()); + case OMNI_VARCHAR: + case OMNI_CHAR: + return createVariableWidthBlock(dictionary, dictionary.getSize()); + case OMNI_DECIMAL128: + return createInt128ArrayBlock(dictionary, dictionary.getSize()); + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support Type " + dataTypeId); + } + } + + public static Block expandDictionary(DictionaryVec dictionaryVec) + { + Vec vec = expandDictionaryVec(dictionaryVec); + DataType dataType = vec.getType(); + Block dictionaryBlock = createFlatBlock(dataType.getId(), vec); + vec.close(); + return dictionaryBlock; + } + + public static int[] getIds(int positionCount) + { + int[] ids = new int[positionCount]; + for (int i = 0; i < positionCount; i++) { + ids[i] = i; + } + return ids; + } + + /** + * DictionaryVec transfer to Vec + * + * @return vector + */ + public static Vec expandDictionaryVec(DictionaryVec dictionaryVec) + { + return dictionaryVec.expandDictionary(); + } + @Override public Vec getValues() { @@ -485,15 +611,14 @@ public class DictionaryOmniBlock int position = positions[offset + i]; newIds[i] = getId(position); } - 
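The new helpers above change how DictionaryOmniBlock exposes its dictionary: instead of wrapping the native dictionary vector in another omni block, expandDictionary(...) expands the DictionaryVec into a flat vector and createFlatBlock(...) copies it into a standard on-heap block (ByteArrayBlock, IntArrayBlock, LongArrayBlock, ShortArrayBlock, VariableWidthBlock or Int128ArrayBlock, chosen by DataTypeId), with the ids reduced to an identity mapping. A sketch of that path, assuming an existing DictionaryVec named dictionaryVec:

    Vec flat = dictionaryVec.expandDictionary();                   // native dictionary -> flat vector
    Block onHeapDictionary = createFlatBlock(flat.getType().getId(), flat);
    flat.close();                                                  // the flat copy is no longer needed
    int[] ids = getIds(dictionaryVec.getSize());                   // identity ids 0..n-1

Downstream of this, the region, copy and compact paths of the class return plain io.prestosql DictionaryBlock instances built from the on-heap dictionary instead of new DictionaryOmniBlock wrappers.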
return new DictionaryOmniBlock((Vec) dictionary.getValues(), newIds); + return new DictionaryBlock(dictionary, newIds); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(positionCount, positionOffset, length); - return new DictionaryOmniBlock(idsOffset + positionOffset, length, dictionaryVec, ids, false, - dictionarySourceId); + return new DictionaryBlock(idsOffset + positionOffset, length, dictionary, ids, false, dictionarySourceId); } @Override @@ -501,7 +626,7 @@ public class DictionaryOmniBlock { checkValidRegion(positionCount, position, length); int[] newIds = Arrays.copyOfRange(ids, idsOffset + position, idsOffset + position + length); - DictionaryOmniBlock dictionaryBlock = new DictionaryOmniBlock((Vec) dictionary.getValues(), newIds); + DictionaryBlock dictionaryBlock = new DictionaryBlock(dictionary, newIds); return dictionaryBlock.compact(); } @@ -532,8 +657,7 @@ public class DictionaryOmniBlock for (int i = 0; i < dictionary.getPositionCount() && isCompact; i++) { isCompact &= seen[i]; } - return new DictionaryOmniBlock(newIds.length, (Vec) dictionary.getValues(), newIds, isCompact, - getDictionarySourceId()); + return new DictionaryBlock(newIds.length, dictionary, newIds, isCompact, getDictionarySourceId()); } @Override @@ -553,8 +677,8 @@ public class DictionaryOmniBlock if (loadedDictionary == dictionary) { return this; } - return new DictionaryOmniBlock(idsOffset, getPositionCount(), (Vec) loadedDictionary.getValues(), ids, false, - randomDictionaryId()); + + return new DictionaryBlock(idsOffset, getPositionCount(), loadedDictionary, ids, false, randomDictionaryId()); } /** @@ -572,9 +696,9 @@ public class DictionaryOmniBlock * * @return the ids */ - Slice getIds() + public int[] getIds() { - return Slices.wrappedIntArray(ids, idsOffset, positionCount); + return ids; } /** diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DoubleArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DoubleArrayOmniBlock.java index 4056c15b99686a568521b9b6387c5eeed0594932..6e78c376cbde6b01a90069bfa2abdfab2aea081b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DoubleArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/DoubleArrayOmniBlock.java @@ -20,21 +20,15 @@ import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.block.LongArrayBlockEncoding; import nova.hetu.omniruntime.vector.DoubleVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; import static java.lang.Double.doubleToLongBits; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Double array omni block. 
@@ -46,46 +40,15 @@ public class DoubleArrayOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(DoubleArrayOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - - private final int arrayOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final DoubleVec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Double array omni block. - * - * @param vecAllocator vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public DoubleArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, - double[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } - - /** - * Instantiates a new Double array omni block. - * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public DoubleArrayOmniBlock(int positionCount, Optional valueIsNull, DoubleVec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Double array omni block. @@ -95,22 +58,24 @@ public class DoubleArrayOmniBlock */ public DoubleArrayOmniBlock(int positionCount, DoubleVec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (Double.BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Double array omni block. * - * @param vecAllocator vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public DoubleArrayOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, byte[] valueIsNull, + public DoubleArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, double[] values) { - this.vecAllocator = vecAllocator; if (arrayOffset < 0) { throw new IllegalArgumentException("arrayOffset is negative"); } @@ -123,7 +88,7 @@ public class DoubleArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new DoubleVec(vecAllocator, positionCount); + this.values = new DoubleVec(positionCount); this.values.put(values, 0, arrayOffset, positionCount); if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { @@ -132,48 +97,8 @@ public class DoubleArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; - } - - this.arrayOffset = 0; - - sizeInBytes = (Double.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Double array omni block. 
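With the allocator and null-array parameters gone, wrapping an existing off-heap vector is now the primary way to build these blocks, and retained size is derived from the vector alone. A usage-level sketch for the double case, assuming a populated DoubleVec; the values and names below are illustrative only:

    DoubleVec vec = new DoubleVec(4);
    vec.put(new double[] {1.0, 2.0, 3.0, 4.0}, 0, 0, 4);
    Block block = new DoubleArrayOmniBlock(4, vec);               // wraps the vec; sizes derived from it
    boolean anyNulls = block.mayHaveNull();                       // answered by vec.hasNull()
    Double second = ((DoubleArrayOmniBlock) block).get(1);        // null if vec.isNull(1), else the value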
- * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - DoubleArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, DoubleVec values) - { - this.vecAllocator = values.getAllocator(); - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - this.arrayOffset = arrayOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); + this.hasNull = true; } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; sizeInBytes = (Double.BYTES + Byte.BYTES) * (long) positionCount; retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); @@ -225,9 +150,6 @@ public class DoubleArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values.get(0, positionCount), (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -268,14 +190,14 @@ public class DoubleArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return values.hasNull(); } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position] == Vec.NULL; + return values.isNull(position); } @Override @@ -290,7 +212,7 @@ public class DoubleArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new DoubleArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, + return new DoubleArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, new double[]{values.get(position)}); } @@ -298,36 +220,25 @@ public class DoubleArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; DoubleVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new DoubleArrayOmniBlock(0, length, newValueIsNull, newValues); + return new DoubleArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - DoubleVec newValues = values.slice(positionOffset, positionOffset + length); - return new DoubleArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + DoubleVec newValues = values.slice(positionOffset, length); + return new DoubleArrayOmniBlock(length, newValues); } @Override public Block copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - - DoubleVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? 
null - : compactArray(valueIsNull, positionOffset + arrayOffset, length); - - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - return new DoubleArrayOmniBlock(0, length, newValueIsNull, newValues); + DoubleVec newValues = values.slice(positionOffset, length); + values.close(); + return new DoubleArrayOmniBlock(length, newValues); } @Override @@ -355,7 +266,7 @@ public class DoubleArrayOmniBlock @Override public Double get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java index b50ab1d0333d024e4507cbcd27cb1e8d2569b65e..c1ee727cffe7e98c9bebdab5a4b5e515f801410c 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java @@ -23,21 +23,15 @@ import io.prestosql.spi.block.Int128ArrayBlockEncoding; import io.prestosql.spi.util.BloomFilter; import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Int 128 array omni block. @@ -54,34 +48,15 @@ public class Int128ArrayOmniBlock */ public static final int INT128_BYTES = Long.BYTES + Long.BYTES; - private final VecAllocator vecAllocator; - - private final int positionOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final Decimal128Vec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Int 128 array omni block. - * - * @param vecAllocator vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public Int128ArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, - long[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Int 128 array omni block. @@ -91,34 +66,24 @@ public class Int128ArrayOmniBlock */ public Int128ArrayOmniBlock(int positionCount, Decimal128Vec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); - } - - /** - * Instantiates a new Int 128 array omni block. 
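The copyRegion rewrite just above is a behavioural change worth calling out: the old path compacted into new arrays and could return this when the block was already compact, while the new path always slices the requested region and then closes the source vector, which reads as an ownership hand-off (the source block should not be used after copyRegion). A sketch of the new shape, with positionOffset and length assumed already range-checked:

    DoubleVec newValues = values.slice(positionOffset, length);   // copy only the requested region
    values.close();                                               // the source vector is released here
    Block copied = new DoubleArrayOmniBlock(length, newValues);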
- * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public Int128ArrayOmniBlock(int positionCount, Optional valueIsNull, Decimal128Vec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (INT128_BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Int 128 array omni block. * - * @param vecAllocator vector allocator * @param positionOffset the position offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public Int128ArrayOmniBlock(VecAllocator vecAllocator, int positionOffset, int positionCount, byte[] valueIsNull, + public Int128ArrayOmniBlock(int positionOffset, int positionCount, byte[] valueIsNull, long[] values) { - this.vecAllocator = vecAllocator; if (positionOffset < 0) { throw new IllegalArgumentException("positionOffset is negative"); } @@ -131,7 +96,7 @@ public class Int128ArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new Decimal128Vec(vecAllocator, positionCount); + this.values = new Decimal128Vec(positionCount); this.values.put(values, 0, positionOffset * 2, positionCount * 2); if (valueIsNull != null && valueIsNull.length - positionOffset < positionCount) { @@ -140,49 +105,8 @@ public class Int128ArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, positionOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, positionOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - - this.positionOffset = 0; - - sizeInBytes = (INT128_BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Int 128 array omni block. 
- * - * @param positionOffset the position offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - Int128ArrayOmniBlock(int positionOffset, int positionCount, byte[] valueIsNull, Decimal128Vec values) - { - this.vecAllocator = values.getAllocator(); - if (positionOffset < 0) { - throw new IllegalArgumentException("positionOffset is negative"); - } - this.positionOffset = positionOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - positionOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; - sizeInBytes = (INT128_BYTES + Byte.BYTES) * (long) positionCount; retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); } @@ -221,9 +145,6 @@ public class Int128ArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values, (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -249,14 +170,14 @@ public class Int128ArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position + positionOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -272,7 +193,7 @@ public class Int128ArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new Int128ArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, + return new Int128ArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, values.get(position)); } @@ -280,20 +201,16 @@ public class Int128ArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; Decimal128Vec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new Int128ArrayOmniBlock(0, length, newValueIsNull, newValues); + return new Int128ArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - Decimal128Vec newValues = values.slice(positionOffset, positionOffset + length); - return new Int128ArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + Decimal128Vec newValues = values.slice(positionOffset, length); + return new Int128ArrayOmniBlock(length, newValues); } @Override @@ -301,14 +218,9 @@ public class Int128ArrayOmniBlock { checkValidRegion(getPositionCount(), positionOffset, length); - Decimal128Vec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? 
null - : compactArray(valueIsNull, positionOffset + positionOffset, length); - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - return new Int128ArrayOmniBlock(0, length, newValueIsNull, newValues); + Decimal128Vec newValues = values.slice(positionOffset, length); + values.close(); + return new Int128ArrayOmniBlock(length, newValues); } @Override @@ -349,7 +261,7 @@ public class Int128ArrayOmniBlock int matchCount = 0; long[] val; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + positionOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } @@ -368,7 +280,7 @@ public class Int128ArrayOmniBlock @Override public long[] get(int position) { - if (valueIsNull != null && valueIsNull[position + positionOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } return values.get(position); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/IntArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/IntArrayOmniBlock.java index 878b9deed88bc882f60b7cd6e4f3c671b7de98cd..35a6c20821a15d426e770aec7e059156f5b34a4d 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/IntArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/IntArrayOmniBlock.java @@ -21,21 +21,14 @@ import io.prestosql.spi.block.IntArrayBlockEncoding; import io.prestosql.spi.util.BloomFilter; import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; -import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Int array omni block. @@ -47,33 +40,15 @@ public class IntArrayOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(IntArrayOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - - private final int arrayOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final IntVec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Int array omni block. - * - * @param vecAllocator the vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public IntArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, int[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Int array omni block. @@ -83,33 +58,23 @@ public class IntArrayOmniBlock */ public IntArrayOmniBlock(int positionCount, IntVec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); - } - - /** - * Instantiates a new Int array omni block. 
- * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public IntArrayOmniBlock(int positionCount, Optional valueIsNull, IntVec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Int array omni block. * - * @param vecAllocator the vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public IntArrayOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, byte[] valueIsNull, int[] values) + public IntArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, int[] values) { - this.vecAllocator = vecAllocator; if (arrayOffset < 0) { throw new IllegalArgumentException("arrayOffset is negative"); } @@ -123,7 +88,7 @@ public class IntArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new IntVec(vecAllocator, positionCount); + this.values = new IntVec(positionCount); this.values.put(values, 0, arrayOffset, positionCount); if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { @@ -132,51 +97,11 @@ public class IntArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - this.arrayOffset = 0; - sizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Int array omni block. 
- * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - IntArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, IntVec values) - { - this.vecAllocator = values.getAllocator(); - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - this.arrayOffset = arrayOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; - - sizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); + retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); } @Override @@ -225,9 +150,6 @@ public class IntArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values.get(0, positionCount), (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -268,14 +190,14 @@ public class IntArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -290,7 +212,7 @@ public class IntArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new IntArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, + return new IntArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, new int[]{values.get(position)}); } @@ -298,20 +220,16 @@ public class IntArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; IntVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new IntArrayOmniBlock(0, length, newValueIsNull, newValues); + return new IntArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - IntVec newValues = values.slice(positionOffset, positionOffset + length); - return new IntArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + IntVec newValues = values.slice(positionOffset, length); + return new IntArrayOmniBlock(length, newValues); } @Override @@ -319,15 +237,9 @@ public class IntArrayOmniBlock { checkValidRegion(getPositionCount(), positionOffset, length); - IntVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? 
null - : compactArray(valueIsNull, positionOffset + arrayOffset, length); - - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - return new IntArrayOmniBlock(0, length, newValueIsNull, newValues); + IntVec newValues = values.slice(positionOffset, length); + values.close(); + return new IntArrayOmniBlock(length, newValues); } @Override @@ -366,7 +278,7 @@ public class IntArrayOmniBlock { int matchCount = 0; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + arrayOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } @@ -382,7 +294,7 @@ public class IntArrayOmniBlock @Override public Integer get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java index 10595ea694a7ac32deff91ef882e86c2fc75bce6..d62bea6c896f3ea9de301e0112ae557ace2becf3 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LazyOmniBlock.java @@ -19,10 +19,6 @@ import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.type.Type; -import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.omniruntime.vector.LazyVec; -import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.function.BiConsumer; @@ -36,16 +32,9 @@ public class LazyOmniBlock { private LazyBlock lazyBlock; - private final LazyVec nativeLazyVec; - - public LazyOmniBlock(VecAllocator vecAllocator, LazyBlock lazyBlock, Type blockType) + public LazyOmniBlock(LazyBlock lazyBlock, Type blockType) { this.lazyBlock = lazyBlock; - nativeLazyVec = new LazyVec(vecAllocator, lazyBlock.getPositionCount(), () -> { - Block block = lazyBlock.getLoadedBlock(); - return (Vec) OperatorUtils.buildOffHeapBlock(vecAllocator, block, block.getClass().getSimpleName(), - block.getPositionCount(), blockType).getValues(); - }); } @Override @@ -54,12 +43,6 @@ public class LazyOmniBlock return true; } - @Override - public Object getValues() - { - return nativeLazyVec; - } - @Override public void writePositionTo(int position, BlockBuilder blockBuilder) { @@ -154,10 +137,4 @@ public class LazyOmniBlock { return lazyBlock; } - - @Override - public void close() - { - nativeLazyVec.close(); - } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LongArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LongArrayOmniBlock.java index c336942cc3b2f7a60a7f88378bb8925f8793c5d1..f792042f6b548fb5c22349b2915c0046168a95a5 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LongArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/LongArrayOmniBlock.java @@ -21,22 +21,15 @@ import io.prestosql.spi.block.LongArrayBlockEncoding; import io.prestosql.spi.util.BloomFilter; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import 
org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; -import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; import static java.lang.Math.toIntExact; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Long array omni block. @@ -48,34 +41,15 @@ public class LongArrayOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongArrayOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - - private final int arrayOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final LongVec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Long array omni block. - * - * @param vecAllocator vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public LongArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, - long[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Long array omni block. @@ -85,34 +59,24 @@ public class LongArrayOmniBlock */ public LongArrayOmniBlock(int positionCount, LongVec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); - } - - /** - * Instantiates a new Long array omni block. - * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public LongArrayOmniBlock(int positionCount, Optional valueIsNull, LongVec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (Long.BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Long array omni block. 
* - * @param vecAllocator vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public LongArrayOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, byte[] valueIsNull, + public LongArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, long[] values) { - this.vecAllocator = vecAllocator; if (arrayOffset < 0) { throw new IllegalArgumentException("arrayOffset is negative"); } @@ -125,7 +89,7 @@ public class LongArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new LongVec(vecAllocator, positionCount); + this.values = new LongVec(positionCount); this.values.put(values, 0, arrayOffset, positionCount); if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { @@ -134,51 +98,11 @@ public class LongArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - this.arrayOffset = 0; - sizeInBytes = (Long.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Long array omni block. - * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public LongArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, LongVec values) - { - vecAllocator = values.getAllocator(); - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - this.arrayOffset = arrayOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; - - sizeInBytes = (Long.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); + retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); } @Override @@ -227,9 +151,6 @@ public class LongArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values.get(0, positionCount), (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -263,7 +184,7 @@ public class LongArrayOmniBlock public Long get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } return values.get(position); @@ -284,14 +205,14 @@ public class LongArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -306,7 +227,7 @@ 
public class LongArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new LongArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, + return new LongArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, new long[]{values.get(position)}); } @@ -314,20 +235,16 @@ public class LongArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; LongVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new LongArrayOmniBlock(0, length, newValueIsNull, newValues); + return new LongArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - LongVec newValues = values.slice(positionOffset, positionOffset + length); - return new LongArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + LongVec newValues = values.slice(positionOffset, length); + return new LongArrayOmniBlock(length, newValues); } @Override @@ -335,16 +252,9 @@ public class LongArrayOmniBlock { checkValidRegion(getPositionCount(), positionOffset, length); - LongVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? null - : compactArray(valueIsNull, positionOffset + arrayOffset, length); - - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - - return new LongArrayOmniBlock(0, length, newValueIsNull, newValues); + LongVec newValues = values.slice(positionOffset, length); + values.close(); + return new LongArrayOmniBlock(length, newValues); } @Override @@ -384,7 +294,7 @@ public class LongArrayOmniBlock { int matchCount = 0; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + arrayOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java index 1e07b51e2c5e41d6b0546f8276881a5ee5084714..013efef4537c40990c2db70af7e91209a7de9c09 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/RowOmniBlock.java @@ -24,7 +24,6 @@ import nova.hetu.omniruntime.type.ContainerDataType; import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.ContainerVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; import javax.annotation.Nullable; @@ -36,7 +35,6 @@ import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static nova.hetu.olk.tool.VecAllocatorHelper.getVecAllocatorFromBlocks; /** * The type Row omni block. 
@@ -49,8 +47,6 @@ public class RowOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(RowOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - private final int startOffset; private final int positionCount; @@ -80,7 +76,7 @@ public class RowOmniBlock * @param fieldBlocks the field blocks * @return the block */ - public static Block fromFieldBlocks(VecAllocator vecAllocator, int positionCount, Optional rowIsNull, + public static Block fromFieldBlocks(int positionCount, Optional rowIsNull, Block[] fieldBlocks, Type blockType, ContainerVec containerVec) { int[] fieldBlockOffsets = new int[positionCount + 1]; @@ -93,8 +89,8 @@ public class RowOmniBlock Block[] newOffHeapFieldBlocks = new Block[fieldBlocks.length]; for (int blockIndex = 0; blockIndex < fieldBlocks.length; ++blockIndex) { Block block = fieldBlocks[blockIndex]; - newOffHeapFieldBlocks[blockIndex] = OperatorUtils.buildOffHeapBlock(vecAllocator, block, - block.getClass().getSimpleName(), block.getPositionCount(), + newOffHeapFieldBlocks[blockIndex] = OperatorUtils.buildOffHeapBlock(block, block.getClass().getSimpleName(), + block.getPositionCount(), blockType == null ? null : blockType.getTypeParameters().get(blockIndex)); } return new RowOmniBlock(0, positionCount, rowIsNull.orElse(null), fieldBlockOffsets, newOffHeapFieldBlocks, @@ -112,7 +108,7 @@ public class RowOmniBlock * @param dataType data type of block * @return the row omni block */ - static RowOmniBlock createRowBlockInternal(int startOffset, int positionCount, @Nullable byte[] rowIsNull, + public static RowOmniBlock createRowBlockInternal(int startOffset, int positionCount, @Nullable byte[] rowIsNull, int[] fieldBlockOffsets, Block[] fieldBlocks, DataType dataType) { validateConstructorArguments(startOffset, positionCount, rowIsNull, fieldBlockOffsets, fieldBlocks); @@ -179,7 +175,7 @@ public class RowOmniBlock long nativeVectorAddress = vec.getNativeVector(); vectorAddresses[i] = nativeVectorAddress; } - ContainerVec vec = new ContainerVec(vecAllocator, numFields, this.getPositionCount(), vectorAddresses, + ContainerVec vec = new ContainerVec(numFields, this.getPositionCount(), vectorAddresses, ((ContainerDataType) dataType).getFieldTypes()); vec.setNulls(0, this.getRowIsNull(), 0, this.getPositionCount()); return vec; @@ -201,7 +197,6 @@ public class RowOmniBlock Block[] fieldBlocks, DataType dataType, ContainerVec containerVec) { super(fieldBlocks.length); - this.vecAllocator = getVecAllocatorFromBlocks(fieldBlocks); this.startOffset = startOffset; this.positionCount = positionCount; diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ShortArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ShortArrayOmniBlock.java index 4af945dd5aba1c620fb3e0ba117f42b348f1f660..f735877448f3012d903a6771eb52bd053d3ddf6e 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ShortArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/ShortArrayOmniBlock.java @@ -21,21 +21,14 @@ import io.prestosql.spi.block.ShortArrayBlockEncoding; import io.prestosql.spi.util.BloomFilter; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.util.Optional; import java.util.function.BiConsumer; import 
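RowOmniBlock follows the same allocator removal: the ContainerVec is now built straight from the field blocks' native vector addresses. A sketch of the container construction using only calls visible in this patch; fieldBlocks are assumed to already be off-heap omni blocks exposing their vector via getValues(), and rowIsNull an existing null byte array:

    long[] vectorAddresses = new long[numFields];
    for (int i = 0; i < numFields; i++) {
        Vec fieldVec = (Vec) fieldBlocks[i].getValues();          // each omni block exposes its native vector
        vectorAddresses[i] = fieldVec.getNativeVector();
    }
    ContainerVec container = new ContainerVec(numFields, positionCount, vectorAddresses,
            ((ContainerDataType) dataType).getFieldTypes());
    container.setNulls(0, rowIsNull, 0, positionCount);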
java.util.function.Function; -import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static io.prestosql.spi.block.BlockUtil.compactArray; import static io.prestosql.spi.block.BlockUtil.countUsedPositions; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Short array omni block. @@ -47,33 +40,15 @@ public class ShortArrayOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(ShortArrayOmniBlock.class).instanceSize(); - private final VecAllocator vecAllocator; - - private final int arrayOffset; - private final int positionCount; - @Nullable - private final byte[] valueIsNull; - private final ShortVec values; private final long sizeInBytes; private final long retainedSizeInBytes; - /** - * Instantiates a new Short array omni block. - * - * @param vecAllocator the vector allocator - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public ShortArrayOmniBlock(VecAllocator vecAllocator, int positionCount, Optional valueIsNull, short[] values) - { - this(vecAllocator, 0, positionCount, valueIsNull.orElse(null), values); - } + private boolean hasNull; /** * Instantiates a new Short array omni block. @@ -83,33 +58,23 @@ public class ShortArrayOmniBlock */ public ShortArrayOmniBlock(int positionCount, ShortVec values) { - this(positionCount, values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty(), values); - } - - /** - * Instantiates a new Short array omni block. - * - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - public ShortArrayOmniBlock(int positionCount, Optional valueIsNull, ShortVec values) - { - this(values.getOffset(), positionCount, valueIsNull.orElse(null), values); + this.positionCount = positionCount; + this.values = values; + this.sizeInBytes = (Short.BYTES + Byte.BYTES) * (long) positionCount; + this.retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); + this.hasNull = values.hasNull(); } /** * Instantiates a new Short array omni block. 
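For the constructor this hunk keeps (the byte[]-based one, whose body continues below), callers hand over plain on-heap arrays and the block copies them into a freshly allocated ShortVec. A small usage sketch with made-up data; the zero byte is assumed here to mean not-null, since the diff only shows Vec.NULL for the null marker:

import io.prestosql.spi.block.Block;
import nova.hetu.olk.block.ShortArrayOmniBlock;
import nova.hetu.omniruntime.vector.Vec;

final class ShortBlockFromHeapSketch
{
    private ShortBlockFromHeapSketch() {}

    static Block build()
    {
        short[] data = {1, 2, 3};
        byte[] nulls = {0, Vec.NULL, 0};   // position 1 is null; 0 assumed to mean not-null
        // The retained constructor copies the range into new ShortVec(positionCount)
        // and pushes the null mask via setNulls, as shown in the hunk below.
        return new ShortArrayOmniBlock(0, data.length, nulls, data);
    }
}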
* - * @param vecAllocator the vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param valueIsNull the value is null * @param values the values */ - public ShortArrayOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, byte[] valueIsNull, short[] values) + public ShortArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, short[] values) { - this.vecAllocator = vecAllocator; if (arrayOffset < 0) { throw new IllegalArgumentException("arrayOffset is negative"); } @@ -123,7 +88,7 @@ public class ShortArrayOmniBlock throw new IllegalArgumentException("values length is less than positionCount"); } - this.values = new ShortVec(vecAllocator, positionCount); + this.values = new ShortVec(positionCount); this.values.put(values, 0, arrayOffset, positionCount); if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { @@ -132,51 +97,11 @@ public class ShortArrayOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - this.arrayOffset = 0; - sizeInBytes = (Short.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); - } - - /** - * Instantiates a new Short array omni block. - * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param valueIsNull the value is null - * @param values the values - */ - ShortArrayOmniBlock(int arrayOffset, int positionCount, byte[] valueIsNull, ShortVec values) - { - this.vecAllocator = values.getAllocator(); - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - this.arrayOffset = arrayOffset; - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } - this.positionCount = positionCount; - - if (values.getSize() < positionCount) { - throw new IllegalArgumentException("values length is less than positionCount"); - } - this.values = values; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("isNull length is less than positionCount"); - } - this.valueIsNull = valueIsNull; - - sizeInBytes = (Short.BYTES + Byte.BYTES) * (long) positionCount; - retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + this.values.getCapacityInBytes(); + retainedSizeInBytes = INSTANCE_SIZE + this.values.getCapacityInBytes(); } @Override @@ -225,9 +150,6 @@ public class ShortArrayOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(values.get(0, positionCount), (long) values.getCapacityInBytes()); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -262,14 +184,14 @@ public class ShortArrayOmniBlock @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override public boolean isNull(int position) { checkReadablePosition(position); - return valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -284,7 +206,7 @@ public class ShortArrayOmniBlock public Block getSingleValueBlock(int position) { checkReadablePosition(position); - return new ShortArrayOmniBlock(vecAllocator, 0, 1, isNull(position) ? 
new byte[]{Vec.NULL} : null, + return new ShortArrayOmniBlock(0, 1, isNull(position) ? new byte[]{Vec.NULL} : null, new short[]{values.get(position)}); } @@ -292,20 +214,16 @@ public class ShortArrayOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; ShortVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); - } - return new ShortArrayOmniBlock(0, length, newValueIsNull, newValues); + return new ShortArrayOmniBlock(length, newValues); } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - ShortVec newValues = values.slice(positionOffset, positionOffset + length); - return new ShortArrayOmniBlock(newValues.getOffset(), length, valueIsNull, newValues); + ShortVec newValues = values.slice(positionOffset, length); + return new ShortArrayOmniBlock(length, newValues); } @Override @@ -313,15 +231,9 @@ public class ShortArrayOmniBlock { checkValidRegion(getPositionCount(), positionOffset, length); - ShortVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null - ? null - : compactArray(valueIsNull, positionOffset + arrayOffset, length); - - if (newValueIsNull == valueIsNull && newValues == values) { - return this; - } - return new ShortArrayOmniBlock(0, length, newValueIsNull, newValues); + ShortVec newValues = values.slice(positionOffset, length); + values.close(); + return new ShortArrayOmniBlock(length, newValues); } @Override @@ -360,7 +272,7 @@ public class ShortArrayOmniBlock { int matchCount = 0; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + arrayOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } @@ -376,7 +288,7 @@ public class ShortArrayOmniBlock @Override public Short get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/VariableWidthOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/VariableWidthOmniBlock.java index 1862dafe44f74b146052f957de687191213c85f7..3ded233e0c5781d361ee30c25eead1f17a013fd2 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/VariableWidthOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/VariableWidthOmniBlock.java @@ -20,25 +20,15 @@ import io.airlift.slice.Slices; import io.prestosql.spi.block.AbstractVariableWidthBlock; import io.prestosql.spi.block.Block; import io.prestosql.spi.util.BloomFilter; -import nova.hetu.omniruntime.vector.JvmUtils; import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - -import java.nio.ByteBuffer; -import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.block.BlockUtil.checkArrayRange; import static io.prestosql.spi.block.BlockUtil.checkValidRegion; -import static 
io.prestosql.spi.block.BlockUtil.compactArray; -import static io.prestosql.spi.block.BlockUtil.compactOffsets; -import static nova.hetu.olk.tool.BlockUtils.compactVec; /** * The type Variable width omni block. @@ -50,47 +40,30 @@ public class VariableWidthOmniBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(VariableWidthOmniBlock.class).instanceSize(); - private final int arrayOffset; - private final int positionCount; private final VarcharVec values; - private final int[] offsets; + private int[] offsets; - @Nullable - private final byte[] valueIsNull; + private Slice slice; private final long retainedSizeInBytes; private final long sizeInBytes; - /** - * Instantiates a new Variable width omni block. - * - * @param vecAllocator vector allocator - * @param positionCount the position count - * @param slice the slice - * @param offsets the offsets - * @param valueIsNull the value is null - */ - public VariableWidthOmniBlock(VecAllocator vecAllocator, int positionCount, Slice slice, int[] offsets, - Optional valueIsNull) - { - this(vecAllocator, 0, positionCount, slice, offsets, valueIsNull.orElse(null)); - } + private boolean hasNull; /** * Instantiates a new Variable width omni block. * - * @param vecAllocator vector allocator * @param arrayOffset the array offset * @param positionCount the position count * @param slice the slice * @param offsets the offsets * @param valueIsNull the value is null */ - public VariableWidthOmniBlock(VecAllocator vecAllocator, int arrayOffset, int positionCount, Slice slice, int[] offsets, + public VariableWidthOmniBlock(int arrayOffset, int positionCount, Slice slice, int[] offsets, byte[] valueIsNull) { if (arrayOffset < 0) { @@ -107,12 +80,11 @@ public class VariableWidthOmniBlock } int dataLength = offsets[arrayOffset + positionCount] - offsets[arrayOffset]; - this.values = new VarcharVec(vecAllocator, dataLength, positionCount); + this.values = new VarcharVec(dataLength, positionCount); if (offsets.length - arrayOffset < (positionCount + 1)) { throw new IllegalArgumentException("offsets length is less than positionCount"); } - this.offsets = compactOffsets(offsets, arrayOffset, positionCount); if (slice.hasByteArray()) { this.values.put(0, slice.byteArray(), slice.byteArrayOffset(), offsets, arrayOffset, positionCount); @@ -124,17 +96,12 @@ public class VariableWidthOmniBlock if (valueIsNull != null) { this.values.setNulls(0, valueIsNull, arrayOffset, positionCount); - this.valueIsNull = compactArray(valueIsNull, arrayOffset, positionCount); - } - else { - this.valueIsNull = null; + this.hasNull = true; } - this.arrayOffset = 0; - - sizeInBytes = offsets[this.arrayOffset + positionCount] - offsets[this.arrayOffset] + sizeInBytes = offsets[arrayOffset + positionCount] - offsets[arrayOffset] + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets); + retainedSizeInBytes = INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull); } /** @@ -145,95 +112,49 @@ public class VariableWidthOmniBlock */ public VariableWidthOmniBlock(int positionCount, VarcharVec values) { - this(positionCount, values, values.getRawValueOffset(), - values.hasNullValue() ? Optional.of(values.getRawValueNulls()) : Optional.empty()); - } - - /** - * Instantiates a new Variable width omni block. 
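The VariableWidthOmniBlock hunks around this point drop the stored int[] offsets and rebuild them on demand from per-position byte lengths. A sketch of that prefix-sum reconstruction, assuming only VarcharVec.getDataLength as used in this diff; it mirrors what loadOffset() below computes on first use:

import nova.hetu.omniruntime.vector.VarcharVec;

final class LazyOffsetsSketch
{
    private LazyOffsetsSketch() {}

    // offsets[i + 1] - offsets[i] is the byte length of position i.
    static int[] offsetsOf(VarcharVec values, int positionCount)
    {
        int[] offsets = new int[positionCount + 1];
        for (int i = 0; i < positionCount; i++) {
            offsets[i + 1] = offsets[i] + values.getDataLength(i);
        }
        return offsets;
    }
}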
- * - * @param positionCount the position count - * @param values the values - * @param offsets the offsets - * @param valuesIsNull the values is null - */ - public VariableWidthOmniBlock(int positionCount, VarcharVec values, int[] offsets, Optional valuesIsNull) - { - this(values.getOffset(), positionCount, values, offsets, valuesIsNull.orElse(null)); - } - - /** - * Instantiates a new Variable width omni block. - * - * @param arrayOffset the array offset - * @param positionCount the position count - * @param values the values - * @param offsets the offsets - * @param valueIsNull the value is null - */ - public VariableWidthOmniBlock(int arrayOffset, int positionCount, VarcharVec values, int[] offsets, - byte[] valueIsNull) - { - if (arrayOffset < 0) { - throw new IllegalArgumentException("arrayOffset is negative"); - } - - if (positionCount < 0) { - throw new IllegalArgumentException("positionCount is negative"); - } this.positionCount = positionCount; - - if (values == null) { - throw new IllegalArgumentException("values is null"); - } - this.values = values; + this.sizeInBytes = getPositionOffset(positionCount) + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); + this.retainedSizeInBytes = INSTANCE_SIZE + values.getCapacityInBytes(); + this.hasNull = values.hasNull(); + } - if (offsets != null && offsets.length - arrayOffset < (positionCount + 1)) { - throw new IllegalArgumentException("offsets length is less than positionCount"); - } - - if (offsets == null) { - throw new IllegalArgumentException("offsets is null"); + private void loadOffset() + { + if (offsets != null) { + return; } - this.offsets = offsets; - - if (valueIsNull != null && valueIsNull.length - arrayOffset < positionCount) { - throw new IllegalArgumentException("valueIsNull length is less than positionCount"); + offsets = new int[positionCount + 1]; + for (int i = 0; i < positionCount; i++) { + offsets[i + 1] = offsets[i] + getSliceLength(i); } - - this.valueIsNull = valueIsNull; - this.arrayOffset = arrayOffset; - - sizeInBytes = offsets[arrayOffset + positionCount] - offsets[arrayOffset] - + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = INSTANCE_SIZE + values.getCapacityInBytes() + sizeOf(valueIsNull) + sizeOf(offsets); } @Override protected final int getPositionOffset(int position) { - return offsets[position + arrayOffset]; + loadOffset(); + return offsets[position]; } @Override public int getSliceLength(int position) { checkReadablePosition(position); - return getPositionOffset(position + 1) - getPositionOffset(position); + return values.getDataLength(position); } @Override public boolean mayHaveNull() { - return valueIsNull != null; + return hasNull; } @Override protected boolean isEntryNull(int position) { - return valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL; + return values.isNull(position); } @Override @@ -251,8 +172,12 @@ public class VariableWidthOmniBlock @Override public long getRegionSizeInBytes(int position, int length) { - return offsets[arrayOffset + position + length] - offsets[arrayOffset + position] - + ((Integer.BYTES + Byte.BYTES) * (long) length); + int sliceLength = 0; + for (int i = 0; i < length; i++) { + sliceLength += values.getDataLength(position); + position++; + } + return sliceLength + ((Integer.BYTES + Byte.BYTES) * (long) length); } @Override @@ -263,7 +188,7 @@ public class VariableWidthOmniBlock for (int i = 0; i < positions.length; ++i) { if (positions[i]) { usedPositionCount++; - sizeInBytes += offsets[arrayOffset + 
i + 1] - offsets[arrayOffset + i]; + sizeInBytes += values.getDataLength(i); } } return sizeInBytes + (Integer.BYTES + Byte.BYTES) * (long) usedPositionCount; @@ -279,10 +204,6 @@ public class VariableWidthOmniBlock public void retainedBytesForEachPart(BiConsumer consumer) { consumer.accept(getRawSlice(0), (long) values.getCapacityInBytes()); - consumer.accept(offsets, sizeOf(offsets)); - if (valueIsNull != null) { - consumer.accept(valueIsNull, sizeOf(valueIsNull)); - } consumer.accept(this, (long) INSTANCE_SIZE); } @@ -290,51 +211,40 @@ public class VariableWidthOmniBlock public Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - byte[] newValueIsNull = null; VarcharVec newValues = values.copyPositions(positions, offset, length); - if (valueIsNull != null) { - newValueIsNull = newValues.getRawValueNulls(); + return new VariableWidthOmniBlock(length, newValues); + } + + private void loadSlice() + { + if (slice != null) { + return; } - int[] newOffsets = newValues.getRawValueOffset(); - return new VariableWidthOmniBlock(0, length, newValues, newOffsets, newValueIsNull); + slice = Slices.wrappedBuffer(values.get(0, positionCount)); } @Override public Slice getRawSlice(int position) { - // use slice wrapped byteBuffer for zero-copy data - ByteBuffer valuesBuf = JvmUtils.directBuffer(values.getValuesBuf()); - valuesBuf.position(0); - if (valuesBuf.capacity() != 0) { - return Slices.wrappedBuffer(valuesBuf); - } - - // empty values - return Slices.wrappedBuffer(); + loadSlice(); + return slice; } @Override public Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - VarcharVec newValues = values.slice(positionOffset, positionOffset + length); - return new VariableWidthOmniBlock(newValues.getOffset(), length, newValues, offsets, valueIsNull); + VarcharVec newValues = values.slice(positionOffset, length); + return new VariableWidthOmniBlock(length, newValues); } @Override public Block copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); - - int[] newOffsets = compactOffsets(offsets, positionOffset + arrayOffset, length); - VarcharVec newValues = compactVec(values, positionOffset, length); - byte[] newValueIsNull = valueIsNull == null ? 
null : compactArray(valueIsNull, positionOffset, length); - - if (newOffsets == offsets && newValues == values && newValueIsNull == valueIsNull) { - return this; - } - - return new VariableWidthOmniBlock(0, length, newValues, offsets, valueIsNull); + VarcharVec newValues = values.slice(positionOffset, length); + values.close(); + return new VariableWidthOmniBlock(length, newValues); } @Override @@ -351,8 +261,7 @@ public class VariableWidthOmniBlock public boolean[] filter(BloomFilter filter, boolean[] validPositions) { for (int i = 0; i < positionCount; i++) { - byte[] value = values.getData(offsets[i + arrayOffset], - offsets[i + arrayOffset + 1] - offsets[i + arrayOffset]); + byte[] value = values.get(i); validPositions[i] = validPositions[i] && filter.test(value); } return validPositions; @@ -363,14 +272,13 @@ public class VariableWidthOmniBlock { int matchCount = 0; for (int i = 0; i < positionCount; i++) { - if (valueIsNull != null && valueIsNull[positions[i] + arrayOffset] == Vec.NULL) { + if (values.isNull(positions[i])) { if (test.apply(null)) { matchedPositions[matchCount++] = positions[i]; } } else { - byte[] value = values.getData(offsets[i + arrayOffset], - offsets[i + arrayOffset + 1] - offsets[i + arrayOffset]); + byte[] value = values.get(positions[i]); if (test.apply(value)) { matchedPositions[matchCount++] = positions[i]; } @@ -382,10 +290,10 @@ public class VariableWidthOmniBlock @Override public byte[] get(int position) { - if (valueIsNull != null && valueIsNull[position + arrayOffset] == Vec.NULL) { + if (values.isNull(position)) { return null; } - return values.getData(offsets[position + arrayOffset], offsets[position + arrayOffset + 1] - offsets[position + arrayOffset]); + return values.get(position); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengAllocatorFactory.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengAllocatorFactory.java deleted file mode 100644 index 49bb558d84ab1e800ce49b570974b387430ff002..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengAllocatorFactory.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package nova.hetu.olk.memory; - -import nova.hetu.omniruntime.vector.VecAllocator; - -import java.util.HashMap; -import java.util.Map; - -public class OpenLooKengAllocatorFactory -{ - private static Map vecAllocators = new HashMap(); - - private OpenLooKengAllocatorFactory() - { - } - - /** - * create the vector allocator with specified scope and call back. 
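On the factory being deleted here: with the allocator plumbing removed across this patch, the scope-keyed allocator cache has no remaining callers; vectors are sized directly at the call site, and accounting presumably moves to the omniruntime MemoryManager that replaces OpenLooKengVecAllocator further down. A before/after sketch using the ShortVec constructor this patch already relies on:

import nova.hetu.omniruntime.vector.ShortVec;

final class DirectVectorAllocationSketch
{
    private DirectVectorAllocationSketch() {}

    static ShortVec allocate(int positionCount)
    {
        // before: new ShortVec(vecAllocator, positionCount), with the allocator threaded in
        //         from the scope-keyed factory being removed here
        // after:  the vector is sized directly, no allocator argument
        return new ShortVec(positionCount);
    }
}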
- * - * @param scope scope the specified scope - * @param createCallBack createCallBack the call back - * @return vector allocator - */ - public static synchronized VecAllocator create(String scope, CallBack createCallBack) - { - VecAllocator allocator = vecAllocators.get(scope); - if (allocator == null) { - allocator = VecAllocator.GLOBAL_VECTOR_ALLOCATOR.newChildAllocator(scope, VecAllocator.UNLIMIT, 0); - vecAllocators.put(scope, new OpenLooKengVecAllocator(allocator.getNativeAllocator())); - if (createCallBack != null) { - createCallBack.callBack(); - } - } - return allocator; - } - - /** - * get the vector allocator by specified scope - * - * @param scope scope the scope for vector - * @return vector allocator - */ - public static synchronized VecAllocator get(String scope) - { - if (vecAllocators.containsKey(scope)) { - return vecAllocators.get(scope); - } - return VecAllocator.GLOBAL_VECTOR_ALLOCATOR; - } - - /** - * delete the vector allocator by specified scope. - * - * @param scope scope the scope for vector - */ - public static synchronized void delete(String scope) - { - VecAllocator allocator = vecAllocators.get(scope); - if (allocator != null) { - vecAllocators.remove(scope); - allocator.close(); - } - } - - /** - * remove this allocator from vecAllocators - * - * @param scope scope the scope for vector - * @return removed allocator - */ - public static synchronized VecAllocator remove(String scope) - { - VecAllocator allocator = vecAllocators.get(scope); - if (allocator != null) { - vecAllocators.remove(scope); - } - return allocator; - } - - /** - * the call back interface - * - * @since 2022-05-16 - */ - public interface CallBack - { - void callBack(); - } -} diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengVecAllocator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengMemoryManager.java similarity index 76% rename from omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengVecAllocator.java rename to omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengMemoryManager.java index bc8825dea9e86fc426821ddd487ab56c70d56d0a..3bef2bd0191ba29f1daeea92628af174bcb3ba45 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengVecAllocator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/memory/OpenLooKengMemoryManager.java @@ -16,19 +16,19 @@ package nova.hetu.olk.memory; import io.airlift.log.Logger; -import nova.hetu.omniruntime.vector.VecAllocator; +import nova.hetu.omniruntime.memory.MemoryManager; import java.util.concurrent.atomic.AtomicBoolean; -public class OpenLooKengVecAllocator - extends VecAllocator +public class OpenLooKengMemoryManager + extends MemoryManager { - private static final Logger log = Logger.get(OpenLooKengVecAllocator.class); + private static final Logger log = Logger.get(OpenLooKengMemoryManager.class); private final AtomicBoolean isClosed = new AtomicBoolean(false); - public OpenLooKengVecAllocator(long nativeAllocator) + public OpenLooKengMemoryManager() { - super(nativeAllocator); + super(); } @Override @@ -46,7 +46,6 @@ public class OpenLooKengVecAllocator public void close() { if (isClosed.compareAndSet(false, true)) { - log.debug("release allocator scope:" + getScope()); super.close(); } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java 
b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java index 69478201316236ad2434234da0f632e84a979137..643c7b7071a8dcedf403451a0052ad3ab8e1b2b3 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java @@ -24,13 +24,11 @@ import io.prestosql.spi.plan.AggregationNode; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.aggregator.OmniAggregationOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.List; @@ -84,7 +82,7 @@ public class AggregationOmniOperator checkState(needsInput(), "Operator is already finishing"); requireNonNull(page, "page is null"); - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); } @@ -200,11 +198,9 @@ public class AggregationOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, AggregationOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, AggregationOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniFactory.createOperator(); return new AggregationOmniOperator(operatorContext, omniOperator); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/BuildOffHeapOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/BuildOffHeapOmniOperator.java index ea1201e563170afe2049dc07ec48471d63f205d2..bf159a90d0d22cacae0dab47020e2e52433557c9 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/BuildOffHeapOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/BuildOffHeapOmniOperator.java @@ -26,8 +26,6 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.List; @@ -45,8 +43,6 @@ public class BuildOffHeapOmniOperator { private final OperatorContext operatorContext; - private final VecAllocator vecAllocator; - private boolean finishing; private Page inputPage; @@ -57,12 +53,10 @@ public class BuildOffHeapOmniOperator * Instantiates a new BuildOffHeap Omni Operator. 
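Several operators in this patch (BuildOffHeapOmniOperator, continued below, EnforceSingleRowOmniOperator, ScanFilterAndProjectOmniOperator) now convert on-heap pages with an allocator-free transferToOffHeapPages. A minimal sketch of that call, assuming the two-argument overload shown in this diff and List<Type> for the generic that the extraction stripped:

import io.prestosql.spi.Page;
import io.prestosql.spi.type.Type;

import java.util.List;

import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages;

final class OffHeapTransferSketch
{
    private OffHeapTransferSketch() {}

    // Mirrors BuildOffHeapOmniOperator.processPage(): the page is rewritten with off-heap
    // Omni blocks using only the column types; no operator-level allocator is created.
    static Page toOffHeap(Page inputPage, List<Type> inputTypes)
    {
        return transferToOffHeapPages(inputPage, inputTypes);
    }
}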
* * @param operatorContext the operator context - * @param vecAllocator the vecAllocator */ - public BuildOffHeapOmniOperator(OperatorContext operatorContext, VecAllocator vecAllocator, List inputTypes) + public BuildOffHeapOmniOperator(OperatorContext operatorContext, List inputTypes) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.vecAllocator = vecAllocator; this.inputTypes = inputTypes; } @@ -112,7 +106,7 @@ public class BuildOffHeapOmniOperator private Page processPage() { - return transferToOffHeapPages(vecAllocator, inputPage, inputTypes); + return transferToOffHeapPages(inputPage, inputTypes); } /** @@ -140,11 +134,9 @@ public class BuildOffHeapOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, VecAllocatorHelper.DEFAULT_RESERVATION, BuildOffHeapOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, BuildOffHeapOmniOperator.class.getSimpleName()); - return new BuildOffHeapOmniOperator(operatorContext, vecAllocator, sourceTypes); + return new BuildOffHeapOmniOperator(operatorContext, sourceTypes); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java index 0d33baf35aaf2b6c7cca0b970be6eb345e36f30a..38ee08299ebfaf31282b6932d05c7a3ca70d803e 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DistinctLimitOmniOperator.java @@ -25,12 +25,10 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.limit.OmniDistinctLimitOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Iterator; @@ -106,11 +104,9 @@ public class DistinctLimitOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, DistinctLimitOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, DistinctLimitOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniDistinctLimitOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniDistinctLimitOperatorFactory.createOperator(); int[] distinctChannelArray = distinctChannels.stream().mapToInt(Integer::intValue).toArray(); int hashChannelVal = this.hashChannel.orElse(-1); @@ -227,7 +223,7 @@ public class DistinctLimitOmniOperator return; } - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, getClass().getSimpleName()); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); pages = new VecBatchToPageIterator(omniOperator.getOutput()); } diff --git 
a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperator.java index a7b294ebc4f183a973cce6457b6b88c37a64fbb9..96370fd34c6ef8e95b2b595fa24e4da0fbee41ca 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperator.java @@ -24,8 +24,6 @@ import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.List; import java.util.Map; @@ -44,7 +42,6 @@ import static java.util.stream.Collectors.toSet; public class DynamicFilterSourceOmniOperator extends DynamicFilterSourceOperator { - private VecAllocator vecAllocator; private Page page; /** @@ -52,10 +49,9 @@ public class DynamicFilterSourceOmniOperator */ public DynamicFilterSourceOmniOperator(OperatorContext context, Consumer> dynamicPredicateConsumer, List channels, PlanNodeId planNodeId, - int maxFilterPositionsCount, DataSize maxFilterSize, VecAllocator vecAllocator) + int maxFilterPositionsCount, DataSize maxFilterSize) { super(context, dynamicPredicateConsumer, channels, planNodeId, maxFilterPositionsCount, maxFilterSize); - this.vecAllocator = vecAllocator; } @Override @@ -113,13 +109,10 @@ public class DynamicFilterSourceOmniOperator public DynamicFilterSourceOperator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, DynamicFilterSourceOmniOperator.class); return new DynamicFilterSourceOmniOperator( driverContext.addOperatorContext(operatorId, planNodeId, DynamicFilterSourceOmniOperator.class.getSimpleName()), - dynamicPredicateConsumer, channels, planNodeId, maxFilterPositionsCount, maxFilterSize, - vecAllocator); + dynamicPredicateConsumer, channels, planNodeId, maxFilterPositionsCount, maxFilterSize); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java index 47aed9bf48f40b709947faa5f82e2d470b5f2297..5685fcf6fdea96a2e1f00f312794469991f4586e 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperator.java @@ -25,8 +25,6 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.List; @@ -39,13 +37,11 @@ import static java.util.Objects.requireNonNull; public class EnforceSingleRowOmniOperator extends EnforceSingleRowOperator { - private VecAllocator vecAllocator; private Page page; - public EnforceSingleRowOmniOperator(OperatorContext operatorContext, VecAllocator vecAllocator) + public EnforceSingleRowOmniOperator(OperatorContext operatorContext) { super(operatorContext); - this.vecAllocator = 
vecAllocator; } @Override @@ -71,7 +67,7 @@ public class EnforceSingleRowOmniOperator return null; } page = null; - return OperatorUtils.transferToOffHeapPages(vecAllocator, output); + return OperatorUtils.transferToOffHeapPages(output); } @Override @@ -109,9 +105,7 @@ public class EnforceSingleRowOmniOperator checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, EnforceSingleRowOmniOperator.class.getSimpleName()); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, EnforceSingleRowOmniOperator.class); - return new EnforceSingleRowOmniOperator(operatorContext, vecAllocator); + return new EnforceSingleRowOmniOperator(operatorContext); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java index 0bf42ef007a34bd54315a6c0bebf46d35e9c7c6f..4efaa7348bc0e484f9260c8d196d33a3cb090bc4 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java @@ -33,13 +33,11 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Arrays; @@ -133,7 +131,7 @@ public class HashAggregationOmniOperator { checkState(!finishing, "Operator is already finishing"); requireNonNull(page, "page is null"); - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); } @@ -309,11 +307,9 @@ public class HashAggregationOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, HashAggregationOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, HashAggregationOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniFactory.createOperator(); return new HashAggregationOmniOperator(operatorContext, omniOperator, step); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java index 3b9efaf18a244f13ed32e94aeea0f9096ce0e4cc..f1e9195f815d71dda859d3c66b83b29c43e4290f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java @@ -35,11 +35,9 @@ import 
io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.join.OmniHashBuilderOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import javax.annotation.concurrent.ThreadSafe; @@ -131,8 +129,6 @@ public class HashBuilderOmniOperator public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, HashBuilderOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, HashBuilderOmniOperator.class.getSimpleName()); @@ -141,7 +137,7 @@ public class HashBuilderOmniOperator int partitionIndex = getAndIncrementPartitionIndex(driverContext.getLifespan()); verify(partitionIndex < lookupSourceFactory.partitions()); - OmniOperator omniOperator = omniHashBuilderOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniHashBuilderOperatorFactory.createOperator(); return new HashBuilderOmniOperator(operatorContext, lookupSourceFactory, partitionIndex, omniOperator); } @@ -290,7 +286,7 @@ public class HashBuilderOmniOperator return; } - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); operatorContext.recordOutput(page.getSizeInBytes(), positionCount); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java index 6863feae5ffd1e80a7a04008bb5dbb04ddcb345c..f7db6575a268881fdc6e3b7877ef345c8a338e0b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LimitOmniOperator.java @@ -20,23 +20,15 @@ import io.prestosql.operator.Operator; import io.prestosql.operator.OperatorContext; import io.prestosql.operator.OperatorFactory; import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.olk.tool.VecBatchToPageIterator; -import nova.hetu.omniruntime.operator.OmniOperator; -import nova.hetu.omniruntime.operator.limit.OmniLimitOperatorFactory; -import nova.hetu.omniruntime.vector.VecAllocator; -import nova.hetu.omniruntime.vector.VecBatch; - -import java.util.Iterator; + import java.util.List; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; -import static nova.hetu.olk.tool.OperatorUtils.buildVecBatch; /** * The type limit omni operator. 
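For contrast with the rewritten limit operator that follows (which no longer goes through the native runtime at all), this is the round-trip shape the native-backed operators in this patch keep: feed a VecBatch to the OmniOperator, then wrap its output back into pages. A sketch only; the element type of the returned iterator follows how the surrounding code consumes it:

import io.prestosql.spi.Page;
import nova.hetu.olk.tool.VecBatchToPageIterator;
import nova.hetu.omniruntime.operator.OmniOperator;
import nova.hetu.omniruntime.vector.VecBatch;

import java.util.Iterator;

final class NativeRoundTripSketch
{
    private NativeRoundTripSketch() {}

    static Iterator<Page> roundTrip(OmniOperator omniOperator, VecBatch vecBatch)
    {
        omniOperator.addInput(vecBatch);
        return new VecBatchToPageIterator(omniOperator.getOutput());
    }
}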
@@ -48,55 +40,43 @@ public class LimitOmniOperator { private long remainingLimit; - private boolean finishing; - - private boolean finished; - private final OperatorContext operatorContext; - private final OmniOperator omniOperator; - - private Iterator pages; // The Pages + private Page nextPage; /** * Instantiates a new Top n omni operator. * * @param operatorContext the operator context - * @param omniOperator the omni operator * @param limit the limit record count */ - public LimitOmniOperator(OperatorContext operatorContext, OmniOperator omniOperator, long limit) + public LimitOmniOperator(OperatorContext operatorContext, long limit) { checkArgument(limit >= 0, "limit must be at least zero"); this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.omniOperator = omniOperator; this.remainingLimit = limit; - this.pages = null; + this.nextPage = null; } @Override public void finish() { - finishing = true; + remainingLimit = 0; } @Override public boolean isFinished() { - return finished; + return remainingLimit == 0 && nextPage == null; } @Override public void close() throws Exception { - // free page if it has next - if (pages != null) { - while (pages.hasNext()) { - Page next = pages.next(); - BlockUtils.freePage(next); - } + // free page if it is not null + if (nextPage != null) { + BlockUtils.freePage(nextPage); } - omniOperator.close(); } @Override @@ -108,48 +88,41 @@ public class LimitOmniOperator @Override public boolean needsInput() { - if (finishing) { - return false; - } - - return true; + return remainingLimit > 0 && nextPage == null; } @Override public void addInput(Page page) { - checkState(!finishing, "Operator is already finishing"); requireNonNull(page, "page is null"); int rowCount = page.getPositionCount(); - if (remainingLimit == 0 || rowCount == 0) { + if (rowCount == 0 || !needsInput()) { BlockUtils.freePage(page); return; } - remainingLimit = (remainingLimit >= rowCount) ? (remainingLimit - rowCount) : 0; - - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, getClass().getSimpleName()); - omniOperator.addInput(vecBatch); - pages = new VecBatchToPageIterator(omniOperator.getOutput()); + if (rowCount <= remainingLimit) { + remainingLimit -= rowCount; + nextPage = page; + } + else { + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < page.getChannelCount(); channel++) { + Block block = page.getBlock(channel); + blocks[channel] = block.getRegion(0, (int) remainingLimit); + } + nextPage = new Page((int) remainingLimit, blocks); + remainingLimit = 0; + BlockUtils.freePage(page); + } } @Override public Page getOutput() { - if (finishing) { - finished = true; - } - - if (pages == null) { - return null; - } - - Page page = null; - if (pages.hasNext()) { - page = pages.next(); - } - pages = null; + Page page = nextPage; + nextPage = null; return page; } @@ -163,8 +136,6 @@ public class LimitOmniOperator { private final long limit; - private final OmniLimitOperatorFactory omniLimitOperatorFactory; - /** * Instantiates a new Top n omni operator factory. 
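The new addInput above enforces the limit purely in Java: when the incoming page is larger than what remains of the limit, each channel is cut with getRegion, a smaller page is kept, and the original page is freed. A sketch of just the truncation step, using only the Page and Block calls that appear in the hunk:

import io.prestosql.spi.Page;
import io.prestosql.spi.block.Block;

final class LimitTruncationSketch
{
    private LimitTruncationSketch() {}

    // Same shape as the else-branch of LimitOmniOperator.addInput(): keep only the first
    // `remaining` rows of every channel and build a smaller Page from the regions.
    static Page truncate(Page page, int remaining)
    {
        Block[] blocks = new Block[page.getChannelCount()];
        for (int channel = 0; channel < page.getChannelCount(); channel++) {
            blocks[channel] = page.getBlock(channel).getRegion(0, remaining);
        }
        return new Page(remaining, blocks);
    }
}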
* @@ -178,23 +149,14 @@ public class LimitOmniOperator this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.limit = limit; this.sourceTypes = sourceTypes; - omniLimitOperatorFactory = getOmniLimitOperatorFactory(limit); - } - - private OmniLimitOperatorFactory getOmniLimitOperatorFactory(long limit) - { - return new OmniLimitOperatorFactory(limit); } @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, LimitOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LimitOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniLimitOperatorFactory.createOperator(vecAllocator); - return new LimitOmniOperator(operatorContext, omniOperator, limit); + return new LimitOmniOperator(operatorContext, limit); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperator.java index 6b09fb6507d5b3740b1d47ae5d267c1feba1c0b9..e9de9baf9070a174b88a479c590a5ea4e618015b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperator.java @@ -29,8 +29,6 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import io.prestosql.sql.gen.OrderingCompiler; import nova.hetu.olk.operator.localexchange.OmniLocalExchange; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.io.IOException; import java.util.List; @@ -93,8 +91,6 @@ public class LocalMergeSourceOmniOperator public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, LocalMergeSourceOmniOperator.class); LocalExchange localExchange = localExchangeFactory.getLocalExchange(driverContext.getLifespan()); @@ -104,7 +100,7 @@ public class LocalMergeSourceOmniOperator List sources = IntStream.range(0, localExchange.getBufferCount()).boxed() .map(index -> localExchange.getNextSource()).collect(toImmutableList()); return new LocalMergeSourceOmniOperator(operatorContext, sources, - orderByOmniOperatorFactory.createOperator(vecAllocator)); + orderByOmniOperatorFactory.createOperator()); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java index 0bed1c1b582989adf87921abfdcfd67692d52fc6..3a81b9d98d5841548b3a313b368654342f37e6fa 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java @@ -38,12 +38,10 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import 
nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.join.OmniLookupJoinOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.io.IOException; @@ -191,7 +189,7 @@ public class LookupJoinOmniOperator BlockUtils.freePage(page); return; } - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); result = new VecBatchToPageIterator(omniOperator.getOutput()); @@ -411,9 +409,6 @@ public class LookupJoinOmniOperator public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, LookupJoinOmniOperator.class); - LookupSourceFactory lookupSourceFactory = joinBridgeManager.getJoinBridge(driverContext.getLifespan()); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, @@ -422,7 +417,7 @@ public class LookupJoinOmniOperator lookupSourceFactory.setTaskContext(driverContext.getPipelineContext().getTaskContext()); joinBridgeManager.probeOperatorCreated(driverContext.getLifespan()); - OmniOperator omniOperator = omniLookupJoinOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniLookupJoinOperatorFactory.createOperator(); return new LookupJoinOmniOperator(operatorContext, sourceTypes, joinType, lookupSourceFactory, () -> joinBridgeManager.probeOperatorClosed(driverContext.getLifespan()), omniOperator); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/MergeOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/MergeOmniOperator.java index 6978dde71517a8bb8742d59e774777d063748d9b..8c8a13354f2d23c68bb84905cff5c3c1f1c3ec30 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/MergeOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/MergeOmniOperator.java @@ -38,8 +38,6 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import io.prestosql.split.RemoteSplit; import io.prestosql.sql.gen.OrderingCompiler; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.io.Closeable; import java.io.IOException; @@ -119,14 +117,12 @@ public class MergeOmniOperator public SourceOperator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, MergeOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, MergeOmniOperator.class.getSimpleName()); return new MergeOmniOperator(operatorContext, planNodeId, exchangeClientSupplier, - serdeFactory.createPagesSerde(), orderByOmniOperatorFactory.createOperator(vecAllocator)); + serdeFactory.createPagesSerde(), orderByOmniOperatorFactory.createOperator()); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java index 
8229784cbece2e087de175121f1ae56f95a773f2..4519556501fd906b767215346291d2925daa56d4 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/OrderByOmniOperator.java @@ -34,12 +34,10 @@ import io.prestosql.testing.TestingSession; import io.prestosql.testing.TestingTaskContext; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.sort.OmniSortOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Iterator; @@ -149,12 +147,10 @@ public class OrderByOmniOperator public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, OrderByOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, OrderByOmniOperator.class.getSimpleName()); - OmniOperator omniSortOperator = omniSortOperatorFactory.createOperator(vecAllocator); + OmniOperator omniSortOperator = omniSortOperatorFactory.createOperator(); return new OrderByOmniOperator(operatorContext, omniSortOperator); } @@ -163,7 +159,7 @@ public class OrderByOmniOperator * * @return the operator */ - public Operator createOperator(VecAllocator vecAllocator) + public Operator createOperator() { // all this is prepared for a fake driverContext to avoid change the original // pipeline @@ -181,7 +177,7 @@ public class OrderByOmniOperator OperatorContext mockOperatorContext = mockDriverContext.addOperatorContext(1, new PlanNodeId("Fake node for creating the OrderByOmniOperator"), "OrderByOmniOperator type"); - OmniOperator omniSortOperator = omniSortOperatorFactory.createOperator(vecAllocator); + OmniOperator omniSortOperator = omniSortOperatorFactory.createOperator(); return new OrderByOmniOperator(mockOperatorContext, omniSortOperator); } @@ -259,7 +255,7 @@ public class OrderByOmniOperator return; } - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java index d86441dc945c4ec7d32a9f5aecfa16704c3cc1cf..6cafad763f327f5e5f975aedba4bfd49583ab26f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/PartitionedOutputOmniOperator.java @@ -43,13 +43,11 @@ import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.partitionedoutput.OmniPartitionedOutPutOperatorFactory; import 
nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.ArrayList; @@ -104,7 +102,7 @@ public class PartitionedOutputOmniOperator return; } page = pagePreprocessor.apply(page); - partitionFunction.partitionPage(omniOperator.getVecAllocator(), page); + partitionFunction.partitionPage(page); operatorContext.recordOutput(page.getSizeInBytes(), page.getPositionCount()); @@ -252,11 +250,9 @@ public class PartitionedOutputOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, PartitionedOutputOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, PartitionedOutputOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniPartitionedOutPutOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniPartitionedOutPutOperatorFactory.createOperator(); String id = operatorContext.getUniqueId(); outputBuffer.addInputChannel(id); return new PartitionedOutputOmniOperator(id, operatorContext, sourceTypes, pagePreprocessor, @@ -430,14 +426,13 @@ public class PartitionedOutputOmniOperator /** * partition Page * - * @param vecAllocator vector allocator * @param page page */ - public void partitionPage(VecAllocator vecAllocator, Page page) + public void partitionPage(Page page) { requireNonNull(page, "page is null"); - VecBatch originalVecBatch = buildVecBatch(vecAllocator, page, this); + VecBatch originalVecBatch = buildVecBatch(page, this); VecBatch originalAndPartitionArgVecBatch = addPartitionFunctionArguments(originalVecBatch); omniOperator.addInput(originalAndPartitionArgVecBatch); @@ -453,8 +448,7 @@ public class PartitionedOutputOmniOperator for (int i = 0; i < partitionChannels.size(); i++) { Optional partitionConstant = partitionConstants.get(i); if (partitionConstant.isPresent()) { - Block block = OperatorUtils.buildOffHeapBlock(omniOperator.getVecAllocator(), - partitionConstant.get()); + Block block = OperatorUtils.buildOffHeapBlock(partitionConstant.get()); // Because there is no vec corresponding to RunLengthEncodedBlock, // the original data is directly constructed. 
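A cross-cutting point visible in these operator hunks: pages that an operator drops rather than forwards are released explicitly (LimitOmniOperator, DistinctLimitOmniOperator and LookupJoinOmniOperator all do this in addInput or close), since the blocks are backed by off-heap vectors. The guard, sketched with only BlockUtils.freePage as it appears in the diff:

import io.prestosql.spi.Page;
import nova.hetu.olk.tool.BlockUtils;

final class FreeDroppedPageSketch
{
    private FreeDroppedPageSketch() {}

    static void releaseIfDropped(Page page)
    {
        if (page != null) {
            BlockUtils.freePage(page);   // releases the page's off-heap blocks
        }
    }
}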
int[] positions = new int[positionCount]; diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/ScanFilterAndProjectOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/ScanFilterAndProjectOmniOperator.java index 5647d439bb00a72d97a931e55dd8243fb126fbde..b6f901d2cbb65816b466971c0469a60b9091c291 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/ScanFilterAndProjectOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/ScanFilterAndProjectOmniOperator.java @@ -70,8 +70,6 @@ import nova.hetu.olk.OmniLocalExecutionPlanner; import nova.hetu.olk.OmniLocalExecutionPlanner.OmniLocalExecutionPlanContext; import nova.hetu.olk.operator.filterandproject.OmniMergePages; import nova.hetu.olk.operator.filterandproject.OmniPageProcessor; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.io.IOException; import java.io.UncheckedIOException; @@ -110,7 +108,6 @@ public class ScanFilterAndProjectOmniOperator private long processedBytes; private long physicalBytes; private long readTimeNanos; - private VecAllocator vecAllocator; private List inputTypes; private OmniMergePages.OmniMergePagesTransformation omniMergePagesTransformation; @@ -125,14 +122,13 @@ public class ScanFilterAndProjectOmniOperator DataSize minOutputPageSize, int minOutputPageRowCount, Optional tableScanNodeOptional, Optional stateStoreProviderOptional, Optional queryIdOptional, Optional metadataOptional, Optional dynamicFilterCacheManagerOptional, - VecAllocator vecAllocator, List inputTypes, OmniLocalExecutionPlanContext context) + List inputTypes, OmniLocalExecutionPlanContext context) { pages = splits.flatTransform(new SplitToPages(session, yieldSignal, pageSourceProvider, cursorProcessor, pageProcessor, table, columns, dynamicFilter, types, requireNonNull(memoryTrackingContext, "memoryTrackingContext is null").aggregateSystemMemoryContext(), minOutputPageSize, minOutputPageRowCount, tableScanNodeOptional, stateStoreProviderOptional, queryIdOptional, metadataOptional, dynamicFilterCacheManagerOptional, context)); - this.vecAllocator = vecAllocator; this.inputTypes = inputTypes; this.pageProcessor = requireNonNull(pageProcessor, "processor is null"); } @@ -409,7 +405,7 @@ public class ScanFilterAndProjectOmniOperator } pageBuilder.reset(); outputMemoryContext.setBytes(pageBuilder.getRetainedSizeInBytes()); - page = transferToOffHeapPages(vecAllocator, page, outputTypes); + page = transferToOffHeapPages(page, outputTypes); return ProcessState.ofResult(page); } else if (finished) { @@ -516,7 +512,7 @@ public class ScanFilterAndProjectOmniOperator log.error("Filter page error: %s", e.getMessage()); } } - page = transferToOffHeapPages(vecAllocator, page, inputTypes); + page = transferToOffHeapPages(page, inputTypes); return ProcessState.ofResult(page); } @@ -560,8 +556,6 @@ public class ScanFilterAndProjectOmniOperator private final Optional spillerFactory; private final Integer spillerThreshold; private final Integer consumerTableScanNodeCount; - private VecAllocator vecAllocator = VecAllocator.GLOBAL_VECTOR_ALLOCATOR; - private OmniLocalExecutionPlanner.OmniLocalExecutionPlanContext context; public ScanFilterAndProjectOmniOperatorFactory(Session session, int operatorId, PlanNodeId planNodeId, @@ -654,9 +648,6 @@ public class ScanFilterAndProjectOmniOperator checkState(!closed, "Factory is already closed"); 
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, getOperatorType()); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, ScanFilterAndProjectOmniOperator.class); - this.vecAllocator = vecAllocator != null ? vecAllocator : VecAllocator.GLOBAL_VECTOR_ALLOCATOR; return new WorkProcessorSourceOperatorAdapter(operatorContext, this, strategy, reuseTableScanMappingId, spillEnabled, types, spillerFactory, spillerThreshold, consumerTableScanNodeCount); } @@ -668,7 +659,7 @@ public class ScanFilterAndProjectOmniOperator pageSourceProvider, cursorProcessor.get(), pageProcessor.get(), table, columns, dynamicFilter, types, minOutputPageSize, minOutputPageRowCount, this.tableScanNodeOptional, this.stateStoreProviderOptional, queryIdOptional, metadataOptional, - dynamicFilterCacheManagerOptional, vecAllocator, sourceTypes, context); + dynamicFilterCacheManagerOptional, sourceTypes, context); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/TopNOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/TopNOmniOperator.java index f5a6cefac31e4e0909404a9f1c26a00b9c298d17..9097ca7aeec8ff63d62f4b8fa324c27294104dc2 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/TopNOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/TopNOmniOperator.java @@ -26,12 +26,10 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.topn.OmniTopNOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Iterator; @@ -125,7 +123,7 @@ public class TopNOmniOperator checkState(!finishing, "Operator is already finishing"); requireNonNull(page, "page is null"); - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); } @@ -222,11 +220,9 @@ public class TopNOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, TopNOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TopNOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniTopNOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniTopNOperatorFactory.createOperator(); return new TopNOmniOperator(operatorContext, omniOperator, topN); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/WindowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/WindowOmniOperator.java index e52973fb4a3a297287b57f06f61faa1e4080738f..c23de10546e2a707c41a7311ccf4d367d25a324d 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/WindowOmniOperator.java +++ 
b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/WindowOmniOperator.java @@ -31,7 +31,6 @@ import io.prestosql.spi.sql.expression.Types; import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.constants.OmniWindowFrameBoundType; @@ -39,7 +38,6 @@ import nova.hetu.omniruntime.constants.OmniWindowFrameType; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.window.OmniWindowOperatorFactory; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Iterator; @@ -126,7 +124,7 @@ public class WindowOmniOperator checkState(!finishing, "Operator is already finishing"); requireNonNull(page, "page is null"); - VecBatch vecBatch = buildVecBatch(omniOperator.getVecAllocator(), page, this); + VecBatch vecBatch = buildVecBatch(page, this); omniOperator.addInput(vecBatch); } @@ -410,11 +408,9 @@ public class WindowOmniOperator @Override public Operator createOperator(DriverContext driverContext) { - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, WindowOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, WindowOmniOperator.class.getSimpleName()); - OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(vecAllocator); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); return new WindowOmniOperator(operatorContext, omniOperator); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/FilterAndProjectOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/FilterAndProjectOmniOperator.java index 0cab5b73e4218890cfc4c098e17312ed7f5b9865..b8a38ed32f0208d16a4ac22854df943438851cb9 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/FilterAndProjectOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/FilterAndProjectOmniOperator.java @@ -15,6 +15,7 @@ package nova.hetu.olk.operator.filterandproject; +import com.esotericsoftware.minlog.Log; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import io.prestosql.memory.context.LocalMemoryContext; @@ -25,10 +26,10 @@ import io.prestosql.operator.OperatorFactory; import io.prestosql.operator.project.PageProcessor; import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.TypeSignature; import nova.hetu.olk.operator.AbstractOmniOperatorFactory; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.List; import java.util.function.Supplier; @@ -144,19 +145,16 @@ public class FilterAndProjectOmniOperator this.minOutputPageSize = requireNonNull(minOutputPageSize, "minOutputPageSize is null"); this.minOutputPageRowCount = minOutputPageRowCount; this.sourceTypes = sourceTypes; - checkDataTypes(this.sourceTypes); } @Override public Operator createOperator(DriverContext 
driverContext) { checkState(!closed, "Factory is already closed"); - VecAllocator vecAllocator = VecAllocatorHelper.createOperatorLevelAllocator(driverContext, - VecAllocator.UNLIMIT, FilterAndProjectOmniOperator.class); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, FilterAndProjectOmniOperator.class.getSimpleName()); return new FilterAndProjectOmniOperator(operatorContext, processor.get(), - new OmniMergingPageOutput(types, minOutputPageSize.toBytes(), minOutputPageRowCount, vecAllocator)); + new OmniMergingPageOutput(types, minOutputPageSize.toBytes(), minOutputPageRowCount)); } @Override @@ -171,5 +169,27 @@ public class FilterAndProjectOmniOperator return new FilterAndProjectOmniOperatorFactory(operatorId, planNodeId, processor, types, minOutputPageSize, minOutputPageRowCount, sourceTypes); } + + public static boolean checkType(Type type) + { + TypeSignature signature = type.getTypeSignature(); + String base = signature.getBase(); + + switch (base) { + case StandardTypes.INTEGER: + case StandardTypes.SMALLINT: + case StandardTypes.BIGINT: + case StandardTypes.DOUBLE: + case StandardTypes.BOOLEAN: + case StandardTypes.VARCHAR: + case StandardTypes.CHAR: + case StandardTypes.DECIMAL: + case StandardTypes.DATE: + return true; + default: + Log.warn("Not support datatype: " + base); + return false; + } + } } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniExpressionCompiler.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniExpressionCompiler.java index 1e538fff9ffc096df05cd3e5e0f03f9e4d147588..ec0ca805176c07d1e9ab95672cd026210dbbc252 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniExpressionCompiler.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniExpressionCompiler.java @@ -32,8 +32,6 @@ import io.prestosql.sql.gen.ExpressionProfiler; import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.relational.RowExpressionDeterminismEvaluator; import nova.hetu.olk.OmniLocalExecutionPlanner.OmniLocalExecutionPlanContext; -import nova.hetu.omniruntime.vector.VecAllocator; -import nova.hetu.omniruntime.vector.VecAllocatorFactory; import javax.inject.Inject; @@ -124,7 +122,6 @@ public class OmniExpressionCompiler List projections, Optional classNameSuffix, OptionalInt initialBatchSize, List inputTypes, TaskId taskId, OmniLocalExecutionPlanContext context) { - VecAllocator vecAllocator = VecAllocatorFactory.get(taskId.toString()); Optional pageFilter; if (filter.isPresent()) { OmniPageFilter omniPageFilter = filterCache @@ -143,8 +140,15 @@ public class OmniExpressionCompiler return null; } - return () -> new OmniPageProcessor(vecAllocator, pageFilter, proj, initialBatchSize, new ExpressionProfiler(), - context); + for (int i = 0; i < inputTypes.size(); i++) { + boolean isSupported = FilterAndProjectOmniOperator.FilterAndProjectOmniOperatorFactory + .checkType(inputTypes.get(i)); + if (!isSupported) { + return null; + } + } + + return () -> new OmniPageProcessor(pageFilter, proj, initialBatchSize, new ExpressionProfiler(), context); } private static final class FilterCacheKey diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergePages.java 
b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergePages.java index 84ee9a0f4152d35f82579d64225cdfbf48f95aa6..027bbd5090a42b3bd7556a7a2adb736d96dbcc99 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergePages.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergePages.java @@ -26,7 +26,6 @@ import nova.hetu.olk.OmniLocalExecutionPlanner; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import org.glassfish.jersey.internal.guava.Lists; @@ -42,7 +41,6 @@ import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.createBlankVectors; import static nova.hetu.olk.tool.OperatorUtils.merge; import static nova.hetu.olk.tool.OperatorUtils.toDataTypes; -import static nova.hetu.olk.tool.VecAllocatorHelper.getVecAllocatorFromBlocks; /** * The type Omni merge pages. @@ -92,8 +90,6 @@ public class OmniMergePages private int maxPageSizeInBytes; - private VecAllocator vecAllocator; - /** * Instantiates a new Omni merge pages. * @@ -199,10 +195,6 @@ public class OmniMergePages */ public void appendPage(Page page) { - // VecAllocator is only created once - if (this.vecAllocator == null) { - this.vecAllocator = getVecAllocatorFromBlocks(page.getBlocks()); - } pages.add(page); totalPositions += page.getPositionCount(); currentPageSizeInBytes = currentPageSizeInBytes + page.getSizeInBytes(); @@ -226,9 +218,9 @@ public class OmniMergePages */ public Page flush() { - VecBatch mergeResult = new VecBatch(createBlankVectors(vecAllocator, dataTypes, totalPositions), + VecBatch mergeResult = new VecBatch(createBlankVectors(dataTypes, totalPositions), totalPositions); - merge(mergeResult, pages, vecAllocator); + merge(mergeResult, pages); Page finalPage = new VecBatchToPageIterator(ImmutableList.of(mergeResult).iterator()).next(); currentPageSizeInBytes = 0; retainedSizeInBytes = 0; diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergingPageOutput.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergingPageOutput.java index 28ce75ccb0ffa8615c6722615515e41792ab511f..fc193ca9e86374dabe3242a8161e70b49805f787 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergingPageOutput.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniMergingPageOutput.java @@ -22,7 +22,6 @@ import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import org.openjdk.jol.info.ClassLayout; @@ -40,7 +39,6 @@ import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_ import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.createBlankVectors; import static nova.hetu.olk.tool.OperatorUtils.merge; -import static nova.hetu.olk.tool.VecAllocatorHelper.getVecAllocatorFromBlocks; /** * This class is intended to be used right after the PageProcessor to 
ensure @@ -86,20 +84,12 @@ public class OmniMergingPageOutput private int totalPositions; private long currentPageSizeInBytes; private long retainedSizeInBytes; - private VecAllocator vecAllocator; public OmniMergingPageOutput(Iterable types, long minPageSizeInBytes, int minRowCount) { this(types, minPageSizeInBytes, minRowCount, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); } - public OmniMergingPageOutput(Iterable types, long minPageSizeInBytes, int minRowCount, - VecAllocator vecAllocator) - { - this(types, minPageSizeInBytes, minRowCount, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - this.vecAllocator = vecAllocator; - } - public OmniMergingPageOutput(Iterable types, long minPageSizeInBytes, int minRowCount, int maxPageSizeInBytes) { @@ -213,10 +203,6 @@ public class OmniMergingPageOutput private void buffer(Page page) { - // VecAllocator is only created once - if (this.vecAllocator == null) { - this.vecAllocator = getVecAllocatorFromBlocks(page.getBlocks()); - } totalPositions += page.getPositionCount(); bufferedPages.add(page); currentPageSizeInBytes = currentPageSizeInBytes + page.getSizeInBytes(); @@ -233,9 +219,9 @@ public class OmniMergingPageOutput return; } - VecBatch resultVecBatch = new VecBatch(createBlankVectors(vecAllocator, dataTypes, totalPositions), + VecBatch resultVecBatch = new VecBatch(createBlankVectors(dataTypes, totalPositions), totalPositions); - merge(resultVecBatch, bufferedPages, vecAllocator); + merge(resultVecBatch, bufferedPages); outputQueue.add(new VecBatchToPageIterator(ImmutableList.of(resultVecBatch).iterator()).next()); // reset buffers diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageFilter.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageFilter.java index 2db0267f891c2ef8ec872b095d3dd487d2d59ea6..34a49bd78e096f82f493fc2e645be9fc3ac5220f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageFilter.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageFilter.java @@ -26,7 +26,6 @@ import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory; import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.utils.OmniRuntimeException; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.Iterator; @@ -110,12 +109,11 @@ public class OmniPageFilter /** * Gets operator. 
* - * @param vecAllocator vector allocator * @return the operator */ - public OmniPageFilterOperator getOperator(VecAllocator vecAllocator) + public OmniPageFilterOperator getOperator() { - return new OmniPageFilterOperator(operatorFactory.createOperator(vecAllocator), inputChannels, inputTypes, + return new OmniPageFilterOperator(operatorFactory.createOperator(), inputChannels, inputTypes, projects); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageProcessor.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageProcessor.java index 7857f48e6516045bff495da7955b53e0849fa902..aa88c681e15f6a8a6bc94f49c31fab981189aa07 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageProcessor.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniPageProcessor.java @@ -33,7 +33,6 @@ import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.VecBatchToPageIterator; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.utils.OmniRuntimeException; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import javax.annotation.concurrent.NotThreadSafe; @@ -62,8 +61,6 @@ public class OmniPageProcessor { private final OmniProjection projection; - private final VecAllocator vecAllocator; - private final OmniLocalExecutionPlanContext context; private Optional omniPageFilterOperator = Optional.empty(); @@ -77,21 +74,20 @@ public class OmniPageProcessor * @param initialBatchSize the initial batch size * @param expressionProfiler the expression profiler */ - public OmniPageProcessor(VecAllocator vecAllocator, Optional filter, OmniProjection proj, + public OmniPageProcessor(Optional filter, OmniProjection proj, OptionalInt initialBatchSize, ExpressionProfiler expressionProfiler, OmniLocalExecutionPlanContext context) { super(filter, Collections.emptyList(), initialBatchSize, expressionProfiler); - this.vecAllocator = vecAllocator; this.context = context; this.projection = requireNonNull(proj, "projection is null"); if (filter.isPresent()) { PageFilter pageFilter = filter.get(); - this.omniPageFilterOperator = Optional.of(((OmniPageFilter) pageFilter).getOperator(vecAllocator)); + this.omniPageFilterOperator = Optional.of(((OmniPageFilter) pageFilter).getOperator()); } else { - this.omniProjectionOperator = Optional.of(projection.getFactory().createOperator(vecAllocator)); + this.omniProjectionOperator = Optional.of(projection.getFactory().createOperator()); } } @@ -143,7 +139,7 @@ public class OmniPageProcessor } Page preloadPage = preloadNeedFilterLazyBlock(page); - VecBatch inputVecBatch = buildVecBatch(vecAllocator, preloadPage, this); + VecBatch inputVecBatch = buildVecBatch(preloadPage, this); if (omniPageFilterOperator.isPresent()) { VecBatch filteredVecBatch = omniPageFilterOperator.get().filterAndProject(inputVecBatch); if (filteredVecBatch == null) { @@ -158,15 +154,13 @@ public class OmniPageProcessor } } - return WorkProcessor.create(new OmniProjectSelectedPositions(vecAllocator, session, yieldSignal, memoryContext, + return WorkProcessor.create(new OmniProjectSelectedPositions(session, yieldSignal, memoryContext, inputVecBatch, positionsRange(0, inputVecBatch.getRowCount()), omniProjectionOperator.get())); } private class OmniProjectSelectedPositions implements WorkProcessor.Process { - 
private final VecAllocator vecAllocator; - private final ConnectorSession session; private final DriverYieldSignal yieldSignal; @@ -183,18 +177,16 @@ public class OmniPageProcessor /** * Instantiates a new Omni project selected positions. * - * @param vecAllocator vector allocator * @param session the session * @param yieldSignal the yield signal * @param memoryContext the memory context * @param vecBatch the page * @param selectedPositions the selected positions */ - public OmniProjectSelectedPositions(VecAllocator vecAllocator, ConnectorSession session, - DriverYieldSignal yieldSignal, LocalMemoryContext memoryContext, VecBatch vecBatch, + public OmniProjectSelectedPositions(ConnectorSession session, DriverYieldSignal yieldSignal, + LocalMemoryContext memoryContext, VecBatch vecBatch, SelectedPositions selectedPositions, OmniOperator omniProjectionOperator) { - this.vecAllocator = vecAllocator; this.omniProjectionOperator = omniProjectionOperator; this.session = session; this.yieldSignal = yieldSignal; diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniRowExpressionUtil.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniRowExpressionUtil.java index 45bc00f45b6ef491482b297aef067733bafbac9c..ec373aef60b4aefa27306e5b8cd0e32a4a2e3a0e 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniRowExpressionUtil.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/filterandproject/OmniRowExpressionUtil.java @@ -366,7 +366,7 @@ public class OmniRowExpressionUtil Optional likeTranslatedFilter = Optional .of(new CallExpression(((CallExpression) translatedExpr.get()).getDisplayName().toUpperCase(Locale.ROOT), ((CallExpression) translatedExpr.get()).getFunctionHandle(), - translatedExpr.get().getType(), newArgs)); + ((CallExpression) translatedExpr.get()).getType(), newArgs)); return likeTranslatedFilter; } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/localexchange/OmniPartitioningExchanger.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/localexchange/OmniPartitioningExchanger.java index 13df4b72a27fe3c47e7af7a7473f5a0a9a0287d4..5e0b79f0cf69f798fb87b9aac7769b09e48877dd 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/localexchange/OmniPartitioningExchanger.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/localexchange/OmniPartitioningExchanger.java @@ -33,8 +33,6 @@ import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; import nova.hetu.olk.operator.filterandproject.OmniMergingPageOutput; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import java.util.Iterator; import java.util.List; @@ -81,9 +79,8 @@ public class OmniPartitioningExchanger for (int i = 0; i < partitionAssignments.length; i++) { partitionAssignments[i] = new IntArrayList(); } - VecAllocator allocator = VecAllocatorHelper.createOperatorLevelAllocator(taskContext, - VecAllocator.UNLIMIT, VecAllocatorHelper.DEFAULT_RESERVATION, OmniPartitioningExchanger.class); - mergingPageOutput = new OmniMergingPageOutput(types, 128000, 256, allocator); + + mergingPageOutput = new OmniMergingPageOutput(types, 128000, 256); } private Iterator> 
createPagesIterator(Page... pages) diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/BlockUtils.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/BlockUtils.java index 3f43fe16acc24a50bfe5878a0e180946230b05ca..fb900bf0b1549dd25e5f73d4513183d0ed313ce7 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/BlockUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/BlockUtils.java @@ -17,13 +17,6 @@ package nova.hetu.olk.tool; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; -import nova.hetu.omniruntime.vector.BooleanVec; -import nova.hetu.omniruntime.vector.Decimal128Vec; -import nova.hetu.omniruntime.vector.DoubleVec; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.ShortVec; -import nova.hetu.omniruntime.vector.VarcharVec; /** * The type Block utils. @@ -36,139 +29,18 @@ public class BlockUtils { } - /** - * Compact vec boolean vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the boolean vec - */ - public static BooleanVec compactVec(BooleanVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - BooleanVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec int vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the int vec - */ - public static IntVec compactVec(IntVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - IntVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec short vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the short vec - */ - public static ShortVec compactVec(ShortVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - ShortVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec long vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the long vec - */ - public static LongVec compactVec(LongVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - LongVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec double vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the double vec - */ - public static DoubleVec compactVec(DoubleVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - DoubleVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec varchar vec. - * - * @param vec the vec - * @param index the index - * @param length the length - * @return the varchar vec - */ - public static VarcharVec compactVec(VarcharVec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - VarcharVec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - - /** - * Compact vec decimal 128 vec. 
- * - * @param vec the vec - * @param index the index - * @param length the length - * @return the decimal 128 vec - */ - public static Decimal128Vec compactVec(Decimal128Vec vec, int index, int length) - { - if (index == 0 && length == vec.getSize() && vec.getOffset() == 0) { - return vec; - } - Decimal128Vec newValues = vec.copyRegion(index, length); - vec.close(); - return newValues; - } - public static void freePage(Page page) { + // release the native vector behind each block Block[] blocks = page.getBlocks(); if (blocks != null) { for (Block block : blocks) { block.close(); } } + // only release the vecBatch when the page is an OmniPage + if (page instanceof OmniPage) { + ((OmniPage) page).close(); + } } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OmniPage.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OmniPage.java new file mode 100644 index 0000000000000000000000000000000000000000..ff3d6477ce95f5bb7fcd2c21f1dd3923ed5cb770 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OmniPage.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.tool; + +import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import java.util.ArrayList; +import java.util.List; + +public class OmniPage + extends Page +{ + private VecBatch vecBatch; + + public OmniPage(int positionCount, VecBatch vecBatch, Block... blocks) + { + super(positionCount, blocks); + this.vecBatch = vecBatch; + } + + public OmniPage(Block... 
blocks) + { + super(blocks[0].getPositionCount(), blocks); + List vecs = new ArrayList<>(); + for (Block block : blocks) { + vecs.add((Vec) block.getValues()); + } + this.vecBatch = new VecBatch(vecs); + } + + public VecBatch getVecBatch() + { + return vecBatch; + } + + public void close() + { + if (vecBatch != null) { + vecBatch.close(); + vecBatch = null; + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java index 59b556c781f79d4ad07e69eb4c2bbba07162e791..a95e2e9d88425d8bc2bca1b3f17d1c467fee37e8 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java @@ -18,7 +18,6 @@ package nova.hetu.olk.tool; import com.google.common.primitives.Ints; import io.airlift.log.Logger; import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.StandardErrorCode; @@ -28,7 +27,6 @@ import io.prestosql.spi.block.ByteArrayBlock; import io.prestosql.spi.block.DictionaryBlock; import io.prestosql.spi.block.Int128ArrayBlock; import io.prestosql.spi.block.IntArrayBlock; -import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.block.LongArrayBlock; import io.prestosql.spi.block.RowBlock; import io.prestosql.spi.block.RunLengthEncodedBlock; @@ -73,7 +71,6 @@ import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; import java.util.ArrayList; @@ -81,12 +78,11 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; -import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.prestosql.spi.type.Decimals.MAX_SHORT_PRECISION; +import static io.prestosql.spi.type.DoubleType.DOUBLE; import static java.lang.Double.doubleToLongBits; import static java.lang.Double.longBitsToDouble; -import static javassist.bytecode.StackMap.DOUBLE; /** * The type Operator utils. @@ -96,7 +92,6 @@ import static javassist.bytecode.StackMap.DOUBLE; public final class OperatorUtils { private static final Logger log = Logger.get(OperatorUtils.class); - private static final int VARCHARVEC_INIT_CAPACITY_PER_POSITION = 200; private OperatorUtils() { @@ -232,12 +227,11 @@ public final class OperatorUtils /** * Create blank vectors for given size and types. 
* - * @param vecAllocator VecAllocator used to create vectors * @param dataTypes data types * @param totalPositions Size for all the vectors * @return List contains blank vectors */ - public static List<Vec> createBlankVectors(VecAllocator vecAllocator, DataType[] dataTypes, int totalPositions) + public static List<Vec> createBlankVectors(DataType[] dataTypes, int totalPositions) { List<Vec> vecsResult = new ArrayList<>(); for (int i = 0; i < dataTypes.length; i++) { @@ -245,33 +239,31 @@ public final class OperatorUtils switch (type.getId()) { case OMNI_INT: case OMNI_DATE32: - vecsResult.add(new IntVec(vecAllocator, totalPositions)); + vecsResult.add(new IntVec(totalPositions)); break; case OMNI_SHORT: - vecsResult.add(new ShortVec(vecAllocator, totalPositions)); + vecsResult.add(new ShortVec(totalPositions)); break; case OMNI_LONG: case OMNI_DECIMAL64: - vecsResult.add(new LongVec(vecAllocator, totalPositions)); + vecsResult.add(new LongVec(totalPositions)); break; case OMNI_DOUBLE: - vecsResult.add(new DoubleVec(vecAllocator, totalPositions)); + vecsResult.add(new DoubleVec(totalPositions)); break; case OMNI_BOOLEAN: - vecsResult.add(new BooleanVec(vecAllocator, totalPositions)); + vecsResult.add(new BooleanVec(totalPositions)); break; case OMNI_VARCHAR: case OMNI_CHAR: - // Blank varcharVec uses 200 bytes for initialization in each position, - // and it will automatically expand capacity if additional capacity is required. - vecsResult.add(new VarcharVec(vecAllocator, totalPositions * VARCHARVEC_INIT_CAPACITY_PER_POSITION, - totalPositions)); + // Blank varcharVec no longer pre-allocates capacity per position; it expands automatically as values are appended. + vecsResult.add(new VarcharVec(totalPositions)); break; case OMNI_DECIMAL128: - vecsResult.add(new Decimal128Vec(vecAllocator, totalPositions)); + vecsResult.add(new Decimal128Vec(totalPositions)); break; case OMNI_CONTAINER: - vecsResult.add(createBlankContainerVector(vecAllocator, type, totalPositions)); + vecsResult.add(createBlankContainerVector(type, totalPositions)); break; default: throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support data type " + type); @@ -280,35 +273,32 @@ public final class OperatorUtils return vecsResult; } - private static ContainerVec createBlankContainerVector(VecAllocator vecAllocator, DataType type, - int totalPositions) + private static ContainerVec createBlankContainerVector(DataType type, int totalPositions) { if (!(type instanceof ContainerDataType)) { throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "type is not container type:" + type); } ContainerDataType containerDataType = (ContainerDataType) type; - List<Vec> fieldVecs = createBlankVectors(vecAllocator, containerDataType.getFieldTypes(), totalPositions); + List<Vec> fieldVecs = createBlankVectors(containerDataType.getFieldTypes(), totalPositions); long[] nativeVec = new long[fieldVecs.size()]; for (int i = 0; i < fieldVecs.size(); i++) { nativeVec[i] = fieldVecs.get(i).getNativeVector(); } - return new ContainerVec(vecAllocator, containerDataType.size(), totalPositions, nativeVec, - containerDataType.getFieldTypes()); + return new ContainerVec(containerDataType.size(), totalPositions, nativeVec, containerDataType.getFieldTypes()); } /** * Transfer to off heap pages list. 
* - * @param vecAllocator vector allocator * @param pages the pages * @return the list */ - public static List<Page> transferToOffHeapPages(VecAllocator vecAllocator, List<Page> pages) + public static List<Page> transferToOffHeapPages(List<Page> pages) { List<Page> offHeapInput = new ArrayList<>(); for (Page page : pages) { - Block[] blocks = getOffHeapBlocks(vecAllocator, page.getBlocks(), null); - offHeapInput.add(new Page(blocks)); + Block[] blocks = getOffHeapBlocks(page.getBlocks(), null); + offHeapInput.add(new OmniPage(blocks)); } return offHeapInput; } @@ -316,47 +306,45 @@ public final class OperatorUtils /** * Transfer to off heap pages page. * - * @param vecAllocator vector allocator * @param page the page * @return the page */ - public static Page transferToOffHeapPages(VecAllocator vecAllocator, Page page) + public static Page transferToOffHeapPages(Page page) { if (page.getBlocks().length == 0) { return page; } - Block[] blocks = getOffHeapBlocks(vecAllocator, page.getBlocks(), null); - return new Page(blocks); + Block[] blocks = getOffHeapBlocks(page.getBlocks(), null); + return new OmniPage(blocks); } /** * Transfer to off heap pages page with types. * - * @param vecAllocator vector allocator * @param page the page * @param blockTypes types * @return the page */ - public static Page transferToOffHeapPages(VecAllocator vecAllocator, Page page, List<Type> blockTypes) + public static Page transferToOffHeapPages(Page page, List<Type> blockTypes) { if (page.getBlocks().length == 0) { return page; } - Block[] blocks = getOffHeapBlocks(vecAllocator, page.getBlocks(), blockTypes); - return new Page(blocks); + Block[] blocks = getOffHeapBlocks(page.getBlocks(), blockTypes); + return new OmniPage(blocks); } - private static Block[] getOffHeapBlocks(VecAllocator vecAllocator, Block[] blocks, List<Type> blockTypes) + private static Block[] getOffHeapBlocks(Block[] blocks, List<Type> blockTypes) { Block[] res = new Block[blocks.length]; if (blockTypes == null || blockTypes.isEmpty()) { for (int i = 0; i < blocks.length; i++) { - res[i] = buildOffHeapBlock(vecAllocator, blocks[i]); + res[i] = buildOffHeapBlock(blocks[i]); } } else { for (int i = 0; i < blocks.length; i++) { - res[i] = buildOffHeapBlock(vecAllocator, blocks[i], blocks[i].getClass().getSimpleName(), + res[i] = buildOffHeapBlock(blocks[i], blocks[i].getClass().getSimpleName(), blocks[i].getPositionCount(), blockTypes.get(i)); } } @@ -366,13 +354,12 @@ public final class OperatorUtils /** * Gets off heap block. 
* - * @param vecAllocator vector allocator * @param block the block * @return the off heap block */ - public static Block buildOffHeapBlock(VecAllocator vecAllocator, Block block) + public static Block buildOffHeapBlock(Block block) { - return buildOffHeapBlock(vecAllocator, block, block.getClass().getSimpleName(), block.getPositionCount(), null); + return buildOffHeapBlock(block, block.getClass().getSimpleName(), block.getPositionCount(), null); } private static double[] transformLongArrayToDoubleArray(long[] values) @@ -410,8 +397,7 @@ public final class OperatorUtils } } - private static Block buildByteArrayOmniBlock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildByteArrayOmniBlock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] valueIsNull = null; @@ -425,18 +411,17 @@ public final class OperatorUtils Arrays.fill(bytes, (byte) block.get(0)); } } - return new ByteArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, bytes); + return new ByteArrayOmniBlock(0, positionCount, valueIsNull, bytes); } else { boolean[] valueIsNull = block.getValueNulls(); int offset = block.getBlockOffset(); byte[] bytes = ((ByteArrayBlock) block).getValues(); - return new ByteArrayOmniBlock(vecAllocator, offset, positionCount, transformBooleanToByte(valueIsNull), bytes); + return new ByteArrayOmniBlock(offset, positionCount, transformBooleanToByte(valueIsNull), bytes); } } - private static Block buildIntArrayOmniBLock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildIntArrayOmniBLock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] valueIsNull = null; @@ -450,18 +435,17 @@ public final class OperatorUtils Arrays.fill(values, (int) block.get(0)); } } - return new IntArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, values); + return new IntArrayOmniBlock(0, positionCount, valueIsNull, values); } else { boolean[] valueIsNull = block.getValueNulls(); int offset = block.getBlockOffset(); int[] values = ((IntArrayBlock) block).getValues(); - return new IntArrayOmniBlock(vecAllocator, offset, positionCount, transformBooleanToByte(valueIsNull), values); + return new IntArrayOmniBlock(offset, positionCount, transformBooleanToByte(valueIsNull), values); } } - private static Block buildShortArrayOmniBLock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildShortArrayOmniBLock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] valueIsNull = null; @@ -475,7 +459,7 @@ public final class OperatorUtils Arrays.fill(values, (short) block.get(0)); } } - return new ShortArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, values); + return new ShortArrayOmniBlock(0, positionCount, valueIsNull, values); } else { ShortArrayBlock shortArrayBlock = (ShortArrayBlock) block; @@ -493,15 +477,14 @@ public final class OperatorUtils } } if (hasNull) { - return new ShortArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, values); + return new ShortArrayOmniBlock(0, positionCount, valueIsNull, values); } } - return new ShortArrayOmniBlock(vecAllocator, 0, positionCount, null, values); + return new ShortArrayOmniBlock(0, positionCount, null, values); } } - private static Block buildLongArrayOmniBLock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildLongArrayOmniBLock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] 
valueIsNull = null; @@ -515,18 +498,17 @@ public final class OperatorUtils Arrays.fill(values, (long) block.get(0)); } } - return new LongArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, values); + return new LongArrayOmniBlock(0, positionCount, valueIsNull, values); } else { boolean[] valueIsNull = block.getValueNulls(); int offset = block.getBlockOffset(); long[] values = ((LongArrayBlock) block).getValues(); - return new LongArrayOmniBlock(vecAllocator, offset, positionCount, transformBooleanToByte(valueIsNull), values); + return new LongArrayOmniBlock(offset, positionCount, transformBooleanToByte(valueIsNull), values); } } - private static Block buildDoubleArrayOmniBLock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildDoubleArrayOmniBLock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] valueIsNull = null; @@ -540,19 +522,18 @@ public final class OperatorUtils Arrays.fill(doubles, longBitsToDouble((long) block.get(0))); } } - return new DoubleArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, doubles); + return new DoubleArrayOmniBlock(0, positionCount, valueIsNull, doubles); } else { boolean[] valueIsNull = block.getValueNulls(); int offset = block.getBlockOffset(); long[] values = ((LongArrayBlock) block).getValues(); double[] doubles = transformLongArrayToDoubleArray(values); - return new DoubleArrayOmniBlock(vecAllocator, offset, positionCount, transformBooleanToByte(valueIsNull), doubles); + return new DoubleArrayOmniBlock(offset, positionCount, transformBooleanToByte(valueIsNull), doubles); } } - private static Block buildInt128ArrayOmniBlock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static Block buildInt128ArrayOmniBlock(Block block, int positionCount, boolean isRLE) { if (isRLE) { byte[] valueIsNull = null; @@ -567,32 +548,30 @@ public final class OperatorUtils fillLongArray(val, longs); } } - return new Int128ArrayOmniBlock(vecAllocator, 0, positionCount, valueIsNull, longs); + return new Int128ArrayOmniBlock(0, positionCount, valueIsNull, longs); } else { boolean[] valueIsNull = block.getValueNulls(); int offset = block.getBlockOffset(); long[] longs = ((Int128ArrayBlock) block).getValues(); - return new Int128ArrayOmniBlock(vecAllocator, offset, positionCount, transformBooleanToByte(valueIsNull), longs); + return new Int128ArrayOmniBlock(offset, positionCount, transformBooleanToByte(valueIsNull), longs); } } - private static VariableWidthOmniBlock buildVariableWidthOmniBlock(VecAllocator vecAllocator, Block block, int positionCount, - boolean isRLE) + private static VariableWidthOmniBlock buildVariableWidthOmniBlock(Block block, int positionCount, boolean isRLE) { if (!isRLE) { int[] offsets = ((VariableWidthBlock) block).getOffsets(); int offset = block.getBlockOffset(); boolean[] valueIsNull = block.getValueNulls(); Slice slice = ((VariableWidthBlock) block).getRawSlice(0); - return new VariableWidthOmniBlock(vecAllocator, offset, positionCount, slice, offsets, + return new VariableWidthOmniBlock(offset, positionCount, slice, offsets, transformBooleanToByte(valueIsNull)); } else { AbstractVariableWidthBlock variableWidthBlock = (AbstractVariableWidthBlock) ((RunLengthEncodedBlock) block) .getValue(); - VarcharVec vec = new VarcharVec(vecAllocator, variableWidthBlock.getSliceLength(0) * positionCount, - positionCount); + VarcharVec vec = new VarcharVec(positionCount); for (int i = 0; i < positionCount; i++) { if (block.isNull(i)) { 
vec.setNull(i); @@ -605,11 +584,11 @@ public final class OperatorUtils } } - private static Block buildDictionaryOmniBlock(VecAllocator vecAllocator, Block inputBlock, Type blockType) + private static Block buildDictionaryOmniBlock(Block inputBlock, Type blockType) { DictionaryBlock dictionaryBlock = (DictionaryBlock) inputBlock; Block block = dictionaryBlock.getDictionary(); - Block omniDictionary = buildOffHeapBlock(vecAllocator, block, block.getClass().getSimpleName(), + Block omniDictionary = buildOffHeapBlock(block, block.getClass().getSimpleName(), block.getPositionCount(), blockType); Block dictionaryOmniBlock = new DictionaryOmniBlock(inputBlock.getPositionCount(), (Vec) omniDictionary.getValues(), dictionaryBlock.getIdsArray()); @@ -617,7 +596,7 @@ public final class OperatorUtils return dictionaryOmniBlock; } - private static Block buildRowOmniBlock(VecAllocator vecAllocator, Block block, int positionCount, Type blockType) + private static Block buildRowOmniBlock(Block block, int positionCount, Type blockType) { byte[] valueIsNull = new byte[positionCount]; RowBlock rowBlock = (RowBlock) block; @@ -626,27 +605,25 @@ public final class OperatorUtils valueIsNull[j] = Vec.NULL; } } - return RowOmniBlock.fromFieldBlocks(vecAllocator, rowBlock.getPositionCount(), Optional.of(valueIsNull), + return RowOmniBlock.fromFieldBlocks(rowBlock.getPositionCount(), Optional.of(valueIsNull), rowBlock.getRawFieldBlocks(), blockType, null); } /** * Gets off heap block. * - * @param vecAllocator vector allocator * @param block the block * @param type the actual block type, e.g. RunLengthEncodedBlock or * DictionaryBlock * @param positionCount the position count of the block * @return the off heap block */ - public static Block buildOffHeapBlock(VecAllocator vecAllocator, Block block, String type, int positionCount, - Type blockType) + public static Block buildOffHeapBlock(Block block, String type, int positionCount, Type blockType) { - return buildOffHeapBlock(vecAllocator, block, type, positionCount, false, blockType); + return buildOffHeapBlock(block, type, positionCount, false, blockType); } - private static Block buildOffHeapBlock(VecAllocator vecAllocator, Block block, String type, int positionCount, boolean isRLE, Type blockType) + private static Block buildOffHeapBlock(Block block, String type, int positionCount, boolean isRLE, Type blockType) { if (block.isExtensionBlock()) { return block; @@ -654,30 +631,28 @@ public final class OperatorUtils switch (type) { case "ByteArrayBlock": - return buildByteArrayOmniBlock(vecAllocator, block, positionCount, isRLE); + return buildByteArrayOmniBlock(block, positionCount, isRLE); case "IntArrayBlock": - return buildIntArrayOmniBLock(vecAllocator, block, positionCount, isRLE); + return buildIntArrayOmniBLock(block, positionCount, isRLE); case "ShortArrayBlock": - return buildShortArrayOmniBLock(vecAllocator, block, positionCount, isRLE); + return buildShortArrayOmniBLock(block, positionCount, isRLE); case "LongArrayBlock": if (blockType != null && blockType.equals(DOUBLE)) { - return buildDoubleArrayOmniBLock(vecAllocator, block, positionCount, isRLE); + return buildDoubleArrayOmniBLock(block, positionCount, isRLE); } - return buildLongArrayOmniBLock(vecAllocator, block, positionCount, isRLE); + return buildLongArrayOmniBLock(block, positionCount, isRLE); case "Int128ArrayBlock": - return buildInt128ArrayOmniBlock(vecAllocator, block, positionCount, isRLE); + return buildInt128ArrayOmniBlock(block, positionCount, isRLE); case "VariableWidthBlock": - 
return buildVariableWidthOmniBlock(vecAllocator, block, positionCount, isRLE); + return buildVariableWidthOmniBlock(block, positionCount, isRLE); case "DictionaryBlock": - return buildDictionaryOmniBlock(vecAllocator, block, blockType); + return buildDictionaryOmniBlock(block, blockType); case "RunLengthEncodedBlock": - return buildOffHeapBlock(vecAllocator, block, - ((RunLengthEncodedBlock) block).getValue().getClass().getSimpleName(), positionCount, true, - blockType); + return buildOffHeapBlock(block, ((RunLengthEncodedBlock) block).getValue().getClass().getSimpleName(), positionCount, true, blockType); case "LazyBlock": - return new LazyOmniBlock(vecAllocator, (LazyBlock) block, blockType); + return loadLazyBlock(block, blockType); case "RowBlock": - return buildRowOmniBlock(vecAllocator, block, positionCount, blockType); + return buildRowOmniBlock(block, positionCount, blockType); default: throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support block:" + type); } @@ -686,14 +661,13 @@ public final class OperatorUtils /** * Build a vector from block. * - * @param vecAllocator vector allocator. * @param block block * @return vector instance. */ - public static Vec buildVec(VecAllocator vecAllocator, Block block) + public static Vec buildVec(Block block) { if (!block.isExtensionBlock()) { - return (Vec) OperatorUtils.buildOffHeapBlock(vecAllocator, block).getValues(); + return (Vec) OperatorUtils.buildOffHeapBlock(block).getValues(); } else { return (Vec) block.getValues(); @@ -703,18 +677,21 @@ public final class OperatorUtils /** * Build a vector by {@link Block} * - * @param vecAllocator VecAllocator to create vectors * @param page the page * @param object the operator * @return the vec batch */ - public static VecBatch buildVecBatch(VecAllocator vecAllocator, Page page, Object object) + public static VecBatch buildVecBatch(Page page, Object object) { + if (page instanceof OmniPage) { + return ((OmniPage) page).getVecBatch(); + } + List vecList = new ArrayList<>(); for (int i = 0; i < page.getChannelCount(); i++) { Block block = page.getBlock(i); - Vec vec = buildVec(vecAllocator, block); + Vec vec = buildVec(block); vecList.add(vec); } @@ -728,7 +705,7 @@ public final class OperatorUtils * * @param resultVecBatch Stores final resulting vectors */ - public static void merge(VecBatch resultVecBatch, List pages, VecAllocator vecAllocator) + public static void merge(VecBatch resultVecBatch, List pages) { for (int channel = 0; channel < resultVecBatch.getVectorCount(); channel++) { int offset = 0; @@ -737,7 +714,7 @@ public final class OperatorUtils Block block = page.getBlock(channel); Vec src; if (!block.isExtensionBlock()) { - block = OperatorUtils.buildOffHeapBlock(vecAllocator, block); + block = OperatorUtils.buildOffHeapBlock(block); } src = (Vec) block.getValues(); Vec dest = resultVecBatch.getVector(channel); @@ -746,6 +723,12 @@ public final class OperatorUtils offset += positionCount; src.close(); } + + for (Page page : pages) { + if (page instanceof OmniPage) { + ((OmniPage) page).close(); + } + } } } @@ -793,7 +776,7 @@ public final class OperatorUtils break; case OMNI_DECIMAL128: rowBlocks[vecIdx] = new Int128ArrayOmniBlock(positionCount, - new Decimal128Vec(containerVec.getVector(vecIdx), dataType)); + new Decimal128Vec(containerVec.getVector(vecIdx))); break; default: throw new PrestoException(GENERIC_INTERNAL_ERROR, @@ -801,7 +784,7 @@ public final class OperatorUtils } } int[] fieldBlockOffsets = new int[positionCount + 1]; - byte[] nulls = 
containerVec.getRawValueNulls(); + byte[] nulls = transformBooleanToByte(containerVec.getValuesNulls(0, positionCount)); for (int position = 0; position < positionCount; position++) { fieldBlockOffsets[position + 1] = fieldBlockOffsets[position] + (nulls[position] == Vec.NULL ? 0 : 1); } @@ -812,7 +795,7 @@ public final class OperatorUtils /** * Transfer to on heap pages list. * - * @param pages the the off heap pages + * @param pages the off heap pages * @return the on heap page list */ public static List transferToOnHeapPages(List pages) @@ -862,31 +845,34 @@ public final class OperatorUtils private static Block buildOnHeapBlock(Block block, String type, int positionCount) { - checkArgument(block.isExtensionBlock(), "block should be omni block!"); - switch (type) { - case "ByteArrayOmniBlock": - return buildByteArrayBlock(block, positionCount); - case "IntArrayOmniBlock": - return buildIntArrayBLock(block, positionCount); - case "ShortArrayOmniBlock": - return buildShortArrayBLock(block, positionCount); - case "LongArrayOmniBlock": - return buildLongArrayBLock(block, positionCount); - case "DoubleArrayOmniBlock": - return buildDoubleArrayBLock(block, positionCount); - case "Int128ArrayOmniBlock": - return buildInt128ArrayBlock(block, positionCount); - case "VariableWidthOmniBlock": - return buildVariableWidthBlock(block, positionCount); - case "DictionaryOmniBlock": - return buildDictionaryBlock(block, positionCount); - case "LazyOmniBlock": - return buildLazyBlock(block); - case "RowOmniBlock": - return buildRowBlock(block, positionCount); - default: - throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support block:" + type); + // normal block is an extension Block, but dictionary block is an olk block. + if (block.isExtensionBlock()) { + switch (type) { + case "ByteArrayOmniBlock": + return buildByteArrayBlock(block, positionCount); + case "IntArrayOmniBlock": + return buildIntArrayBLock(block, positionCount); + case "ShortArrayOmniBlock": + return buildShortArrayBLock(block, positionCount); + case "LongArrayOmniBlock": + return buildLongArrayBLock(block, positionCount); + case "DoubleArrayOmniBlock": + return buildDoubleArrayBLock(block, positionCount); + case "Int128ArrayOmniBlock": + return buildInt128ArrayBlock(block, positionCount); + case "VariableWidthOmniBlock": + return buildVariableWidthBlock(block, positionCount); + case "DictionaryOmniBlock": + return buildDictionaryBlock(block, positionCount); + case "LazyOmniBlock": + return buildLazyBlock(block); + case "RowOmniBlock": + return buildRowBlock(block, positionCount); + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support block:" + type); + } } + return block; } private static Block buildRowBlock(Block block, int positionCount) @@ -908,21 +894,21 @@ public final class OperatorUtils private static Block buildDictionaryBlock(Block block, int positionCount) { DictionaryVec dictionaryVec = (DictionaryVec) block.getValues(); - int[] newIds = dictionaryVec.getIds(positionCount); Block dictionary = buildOnHeapBlock(((DictionaryOmniBlock) block).getDictionary()); - return new DictionaryBlock(dictionary, newIds); + return new DictionaryBlock(dictionary, ((DictionaryOmniBlock) block).getIds()); } private static Block buildVariableWidthBlock(Block block, int positionCount) { + VariableWidthOmniBlock offHeapBlock = (VariableWidthOmniBlock) block; + Slice slice = offHeapBlock.getRawSlice(0); + int[] offsets = new int[positionCount + 1]; + for (int i = 0; i < positionCount; i++) { + offsets[i + 
1] = offsets[i] + offHeapBlock.getSliceLength(i); + } VarcharVec varcharVec = (VarcharVec) block.getValues(); - int[] offsets = varcharVec.getValueOffset(0, positionCount); - int startOffset = varcharVec.getValueOffset(0); - int endOffset = varcharVec.getValueOffset(positionCount); - byte[] data = varcharVec.getData(startOffset, endOffset - startOffset); - Slice slice = Slices.wrappedBuffer(data); return new VariableWidthBlock(positionCount, slice, offsets, - varcharVec.hasNullValue() + varcharVec.hasNull() ? Optional.of(varcharVec.getValuesNulls(0, positionCount)) : Optional.empty()); } @@ -971,7 +957,14 @@ public final class OperatorUtils private static Block buildByteArrayBlock(Block block, int positionCount) { BooleanVec booleanVec = (BooleanVec) block.getValues(); - byte[] bytes = booleanVec.getValuesBuf().getBytes(booleanVec.getOffset(), positionCount); + byte[] bytes = booleanVec.getValuesBuf().getBytes(0, positionCount); return new ByteArrayBlock(positionCount, Optional.of(booleanVec.getValuesNulls(0, positionCount)), bytes); } + + private static Block loadLazyBlock(Block lazyBlock, Type blockType) + { + Block block = lazyBlock.getLoadedBlock(); + return buildOffHeapBlock(block, block.getClass().getSimpleName(), + block.getPositionCount(), blockType); + } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecAllocatorHelper.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecAllocatorHelper.java deleted file mode 100644 index e97b5e4d5b3e168f39354e92e66c45d00c3c5af1..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecAllocatorHelper.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package nova.hetu.olk.tool; - -import io.airlift.log.Logger; -import io.prestosql.execution.TaskId; -import io.prestosql.execution.TaskState; -import io.prestosql.operator.DriverContext; -import io.prestosql.operator.TaskContext; -import io.prestosql.spi.block.Block; -import nova.hetu.olk.memory.OpenLooKengAllocatorFactory; -import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; - -public class VecAllocatorHelper -{ - private static final Logger log = Logger.get(VecAllocatorHelper.class); - - private static final String VECTOR_ALLOCATOR_PROPERTY_NAME = "vector_allocator"; - - public static final long DEFAULT_RESERVATION = 1 << 20; // 1MB - - private VecAllocatorHelper() - { - } - - public static void setVectorAllocatorToTaskContext(TaskContext taskContext, VecAllocator vecAllocator) - { - taskContext.getTaskExtendProperties().put(VECTOR_ALLOCATOR_PROPERTY_NAME, vecAllocator); - } - - private static VecAllocator getVecAllocatorFromTaskContext(TaskContext taskContext) - { - Object obj = taskContext.getTaskExtendProperties().get(VECTOR_ALLOCATOR_PROPERTY_NAME); - if (obj instanceof VecAllocator) { - return (VecAllocator) obj; - } - return VecAllocator.GLOBAL_VECTOR_ALLOCATOR; - } - - public static VecAllocator getVecAllocatorFromBlocks(Block[] blocks) - { - for (Block block : blocks) { - if (block.isExtensionBlock()) { - return ((Vec) block.getValues()).getAllocator(); - } - } - return VecAllocator.GLOBAL_VECTOR_ALLOCATOR; - } - - /** - * create an operator level allocator based on driver context. - * - * @param driverContext diver context - * @param limit allocator limit - * @param jazz operator Class - * @return operator allocator - */ - public static VecAllocator createOperatorLevelAllocator(DriverContext driverContext, long limit, Class jazz) - { - TaskContext taskContext = driverContext.getPipelineContext().getTaskContext(); - VecAllocator vecAllocator = getVecAllocatorFromTaskContext(taskContext); - return createOperatorLevelAllocator(vecAllocator, limit, taskContext.getTaskId().toString(), 0, jazz); - } - - /** - * create an operator level allocator based on driver context. - * - * @param driverContext diver context - * @param limit allocator limit - * @param reservation reservation - * @param jazz operator Class - * @return operator allocator - */ - public static VecAllocator createOperatorLevelAllocator(DriverContext driverContext, long limit, long reservation, - Class jazz) - { - TaskContext taskContext = driverContext.getPipelineContext().getTaskContext(); - VecAllocator vecAllocator = getVecAllocatorFromTaskContext(taskContext); - return createOperatorLevelAllocator(vecAllocator, limit, taskContext.getTaskId().toString(), reservation, jazz); - } - - /** - * create an operator level allocator base on a vecAllocator. - * - * @param parent parent vecAllocator - * @param limit allocator limit - * @param prefix taskId - * @param reservation allocator default reservation - * @param jazz operator Class - * @return operator allocator - */ - private static VecAllocator createOperatorLevelAllocator(VecAllocator parent, long limit, String prefix, - long reservation, Class jazz) - { - if (parent == VecAllocator.GLOBAL_VECTOR_ALLOCATOR || parent == null) { - return VecAllocator.GLOBAL_VECTOR_ALLOCATOR; - } - return parent.newChildAllocator(prefix + jazz.getSimpleName(), limit, reservation); - } - - /** - * create an operator level allocator based on task context. 
- * - * @param taskContext task context - * @param limit allocator limit - * @param jazz operator Class - * @return operator allocator - */ - public static VecAllocator createOperatorLevelAllocator(TaskContext taskContext, long limit, long reservation, - Class jazz) - { - VecAllocator vecAllocator = getVecAllocatorFromTaskContext(taskContext); - return createOperatorLevelAllocator(vecAllocator, limit, taskContext.getTaskId().toString(), 0, jazz); - } - - /** - * create task level allocator - * - * @param taskContext task context - * @return task vec allocator - */ - public static VecAllocator createTaskLevelAllocator(TaskContext taskContext) - { - TaskId taskId = taskContext.getTaskId(); - VecAllocator vecAllocator = OpenLooKengAllocatorFactory.create(taskId.toString(), () -> { - taskContext.getTaskStateMachine().addStateChangeListenerToTail(state -> { - if (state.isDone()) { - if (state == TaskState.FINISHED) { - OpenLooKengAllocatorFactory.delete(taskId.toString()); - } - else { - // CANCELED, ABORTED, FAILED and so on, wait for the completion of all drivers fo the task, - // here the allocator will be released when the gc recycles - VecAllocator removedAllocator = OpenLooKengAllocatorFactory.remove(taskId.toString()); - if (removedAllocator != null) { - log.debug("remove allocator from cache:" + removedAllocator.getScope()); - } - } - } - }); - }); - VecAllocatorHelper.setVectorAllocatorToTaskContext(taskContext, vecAllocator); - return vecAllocator; - } -} diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecBatchToPageIterator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecBatchToPageIterator.java index a73ddc1ea0260ad42692b30aefb71fe24a85e4c9..41ae262f3067b4f5419c02154237a064bb1cfe15 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecBatchToPageIterator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/VecBatchToPageIterator.java @@ -114,7 +114,6 @@ public class VecBatchToPageIterator throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Unsupported vector type " + vectors[i]); } } - vecBatch.close(); - return new Page(positionCount, blocks); + return new OmniPage(positionCount, vecBatch, blocks); } } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/TestOmniLocalExecutionPlanner.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/TestOmniLocalExecutionPlanner.java index 290af573f582d9db08795d4a66659f1a52225440..53dcfa0b88382515221806dbc15ba8136a40373a 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/TestOmniLocalExecutionPlanner.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/TestOmniLocalExecutionPlanner.java @@ -24,7 +24,6 @@ import io.prestosql.operator.TaskContext; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.sql.planner.NodePartitioningManager; import io.prestosql.sql.planner.PartitioningScheme; -import nova.hetu.olk.memory.OpenLooKengAllocatorFactory; import nova.hetu.olk.operator.AggregationOmniOperator; import nova.hetu.olk.operator.DistinctLimitOmniOperator; import nova.hetu.olk.operator.HashAggregationOmniOperator; @@ -38,7 +37,6 @@ import nova.hetu.olk.operator.TopNOmniOperator; import nova.hetu.olk.operator.WindowOmniOperator; import nova.hetu.olk.operator.filterandproject.OmniExpressionCompiler; import nova.hetu.omniruntime.constants.FunctionType; -import 
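// Descriptive note on the VecBatchToPageIterator change above (an inference, not stated in the
// patch): the iterator no longer calls vecBatch.close() before building the page; the batch is
// handed to the new OmniPage instead, which appears to take ownership so the off-heap vectors
// remain valid for the page's lifetime and are presumably released when the OmniPage is freed.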
nova.hetu.omniruntime.vector.VecAllocator; import org.powermock.api.support.membermodification.MemberModifier; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -58,14 +56,12 @@ import static io.prestosql.SessionTestUtils.TEST_SESSION; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyList; -import static org.mockito.Matchers.anyString; import static org.powermock.api.mockito.PowerMockito.mock; import static org.powermock.api.mockito.PowerMockito.mockStatic; import static org.powermock.api.mockito.PowerMockito.when; import static org.powermock.api.mockito.PowerMockito.whenNew; -@PrepareForTest({VecAllocator.class, - OpenLooKengAllocatorFactory.class, +@PrepareForTest({ OmniLocalQueryRunner.class, OmniLocalExecutionPlanner.class, AggregationOmniOperator.class, @@ -76,7 +72,7 @@ import static org.powermock.api.mockito.PowerMockito.whenNew; PartitionFunction.class, NodePartitioningManager.class }) -@SuppressStaticInitializationFor({"nova.hetu.omniruntime.vector.VecAllocator", +@SuppressStaticInitializationFor({ "nova.hetu.omniruntime.constants.Constant", "nova.hetu.omniruntime.operator.OmniOperatorFactory" }) @@ -151,11 +147,6 @@ public class TestOmniLocalExecutionPlanner private void mockSupports() throws Exception { - //mock VecAllocator - VecAllocator vecAllocator = mock(VecAllocator.class); - mockStatic(OpenLooKengAllocatorFactory.class); - when(OpenLooKengAllocatorFactory.create(anyString(), any(OpenLooKengAllocatorFactory.CallBack.class))).thenReturn(vecAllocator); - //mock AggOmniOperator AggregationOmniOperator aggregationOmniOperator = mock(AggregationOmniOperator.class); AggregationOmniOperator.AggregationOmniOperatorFactory aggregationOmniOperatorFactory = mock(AggregationOmniOperator.AggregationOmniOperatorFactory.class); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/AbstractBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/AbstractBlockTest.java index 1ee1564cc68b4e3c753bf310eaa8c0dc4a137317..5f584324080da48d8b9b7e0ea8e44780a58bff20 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/AbstractBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/AbstractBlockTest.java @@ -16,8 +16,6 @@ package nova.hetu.olk.block; import io.prestosql.spi.block.Block; -import nova.hetu.olk.mock.MockUtil; -import nova.hetu.omniruntime.vector.VecAllocator; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; import org.powermock.modules.testng.PowerMockTestCase; @@ -32,21 +30,11 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -@SuppressStaticInitializationFor({ - "nova.hetu.omniruntime.vector.VecAllocator", - "nova.hetu.omniruntime.vector.Vec" -}) +@SuppressStaticInitializationFor("nova.hetu.omniruntime.vector.Vec") @PowerMockIgnore("javax.management.*") public class AbstractBlockTest extends PowerMockTestCase { - private VecAllocator vecAllocator; - - protected final VecAllocator getVecAllocator() - { - return vecAllocator; - } - @DataProvider(name = "blockProvider") public Object[][] dataProvider() { @@ -69,7 +57,6 @@ public class AbstractBlockTest @BeforeMethod public void setUp() { - vecAllocator = 
MockUtil.mockNewVecWithAnyArguments(VecAllocator.class); this.setupMock(); } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/BenchmarkOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/BenchmarkOmniBlock.java index 346f35bd62f3f6168912e09d7a6b8f3f82d3eb04..a4e0dc4791c1883fb35a9e80dba5d8ea7253538a 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/BenchmarkOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/BenchmarkOmniBlock.java @@ -23,7 +23,6 @@ import io.prestosql.spi.type.Type; import nova.hetu.olk.operator.benchmark.PageBuilderUtil; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -124,7 +123,7 @@ public class BenchmarkOmniBlock pages.add(PageBuilderUtil.createSequencePage(typesArray, ROWS_PER_PAGE)); } } - return OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + return OperatorUtils.transferToOffHeapPages(pages); } public List getPages() diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ByteArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ByteArrayOmniBlockTest.java index d4af26d200e038ad94115604c7105d50e1d9c968..98dce525274a90db52ee05f6718f30dd8eb7bb18 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ByteArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ByteArrayOmniBlockTest.java @@ -61,9 +61,9 @@ public class ByteArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Boolean[1], index -> new Random().nextBoolean())), - mockBlock(false, false, getVecAllocator(), fill(new Boolean[2], index -> new Random().nextBoolean())), - mockBlock(false, false, getVecAllocator(), fill(new Boolean[3], index -> new Random().nextBoolean())), + mockBlock(false, false, fill(new Boolean[1], index -> new Random().nextBoolean())), + mockBlock(false, false, fill(new Boolean[2], index -> new Random().nextBoolean())), + mockBlock(false, false, fill(new Boolean[3], index -> new Random().nextBoolean())), }; } @@ -71,7 +71,6 @@ public class ByteArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); block.getByte(0, 0); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DictionaryOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DictionaryOmniBlockTest.java index f782c3000be09996a735a3ee7f8994abb81e62af..7a1cfea1ac7bce79fff346039ad6a76011090b65 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DictionaryOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DictionaryOmniBlockTest.java @@ -95,9 +95,9 @@ public class DictionaryOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, true, getVecAllocator(), 
fill(new Boolean[1], index -> new Random().nextBoolean())), - mockBlock(false, true, getVecAllocator(), fill(new Boolean[2], index -> new Random().nextBoolean())), - mockBlock(false, true, getVecAllocator(), fill(new Boolean[3], index -> new Random().nextBoolean())), + mockBlock(false, true, fill(new Boolean[1], index -> new Random().nextBoolean())), + mockBlock(false, true, fill(new Boolean[2], index -> new Random().nextBoolean())), + mockBlock(false, true, fill(new Boolean[3], index -> new Random().nextBoolean())), }; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DoubleArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DoubleArrayOmniBlockTest.java index 91f58589a99faa85b38928f4e83e2938f7d6f641..3d3cc0a415ee8763b3bb7ca2713295ec80dd8559 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DoubleArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/DoubleArrayOmniBlockTest.java @@ -61,9 +61,9 @@ public class DoubleArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Double[1], index -> new Random().nextDouble())), - mockBlock(false, false, getVecAllocator(), fill(new Double[2], index -> new Random().nextDouble())), - mockBlock(false, false, getVecAllocator(), fill(new Double[3], index -> new Random().nextDouble())), + mockBlock(false, false, fill(new Double[1], index -> new Random().nextDouble())), + mockBlock(false, false, fill(new Double[2], index -> new Random().nextDouble())), + mockBlock(false, false, fill(new Double[3], index -> new Random().nextDouble())), }; } @@ -71,7 +71,6 @@ public class DoubleArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); block.writePositionTo(0, mock(BlockBuilder.class)); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/Int128ArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/Int128ArrayOmniBlockTest.java index af7ff3bab7968395875dd0a5b215e7109a2670c4..d186814a4ad369a2379d139c565fc8e10859ccc2 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/Int128ArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/Int128ArrayOmniBlockTest.java @@ -56,9 +56,9 @@ public class Int128ArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Long[1][], index -> fill(new Long[2], idx -> new Random().nextLong()))), - mockBlock(false, false, getVecAllocator(), fill(new Long[2][], index -> fill(new Long[2], idx -> new Random().nextLong()))), - mockBlock(false, false, getVecAllocator(), fill(new Long[3][], index -> fill(new Long[2], idx -> new Random().nextLong()))) + mockBlock(false, false, fill(new Long[1][], index -> fill(new Long[2], idx -> new Random().nextLong()))), + mockBlock(false, false, fill(new Long[2][], index -> fill(new Long[2], idx -> new Random().nextLong()))), + mockBlock(false, false, fill(new Long[3][], index -> fill(new Long[2], idx -> new Random().nextLong()))) }; } @@ -66,7 +66,6 @@ public class 
Int128ArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); block.getSingleValueBlock(0); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/IntArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/IntArrayOmniBlockTest.java index d1a9d2235ad56f613b8605d256f69a373a71c395..a205d2435fdde155cf61b8de6d8cbb2b190bed93 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/IntArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/IntArrayOmniBlockTest.java @@ -51,9 +51,9 @@ public class IntArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Integer[1], index -> new Random().nextInt())), - mockBlock(false, false, getVecAllocator(), fill(new Integer[2], index -> new Random().nextInt())), - mockBlock(false, false, getVecAllocator(), fill(new Integer[3], index -> new Random().nextInt())), + mockBlock(false, false, fill(new Integer[1], index -> new Random().nextInt())), + mockBlock(false, false, fill(new Integer[2], index -> new Random().nextInt())), + mockBlock(false, false, fill(new Integer[3], index -> new Random().nextInt())), }; } @@ -71,7 +71,6 @@ public class IntArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); block.getLong(0, 0); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LazyOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LazyOmniBlockTest.java index 20dd08abcc4a75ad94e9b370f229b071740e0ce6..1910b49c5c2accda88410c201e3d5576db589548 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LazyOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LazyOmniBlockTest.java @@ -22,7 +22,6 @@ import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.DoubleVec; import nova.hetu.omniruntime.vector.FixedWidthVec; import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LazyVec; import org.powermock.core.classloader.annotations.PrepareForTest; import org.testng.annotations.Test; @@ -49,7 +48,6 @@ public class LazyOmniBlockTest protected void setupMock() { super.setupMock(); - mockNewVecWithAnyArguments(LazyVec.class); mockNewVecWithAnyArguments(BooleanVec.class); mockNewVecWithAnyArguments(IntVec.class); mockNewVecWithAnyArguments(DoubleVec.class); @@ -62,7 +60,6 @@ public class LazyOmniBlockTest LazyOmniBlock original = (LazyOmniBlock) block; LazyBlock lazyBlock = original.getLazyBlock(); assertTrue(block.isExtensionBlock()); - assertTrue(block.getValues() instanceof LazyVec); for (int i = 0; i < original.getPositionCount(); i++) { assertEquals(original.getEncodingName(), lazyBlock.getEncodingName()); } @@ -73,9 +70,9 @@ public class LazyOmniBlockTest { setupMock(); return new Block[]{ - mockBlock(true, false, getVecAllocator(), fill(new Boolean[3], index -> new 
Random().nextBoolean())), - mockBlock(true, false, getVecAllocator(), fill(new Integer[3], index -> new Random().nextInt())), - mockBlock(true, false, getVecAllocator(), fill(new Double[3], index -> new Random().nextDouble())), + mockBlock(true, false, fill(new Boolean[3], index -> new Random().nextBoolean())), + mockBlock(true, false, fill(new Integer[3], index -> new Random().nextInt())), + mockBlock(true, false, fill(new Double[3], index -> new Random().nextDouble())), }; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LongArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LongArrayOmniBlockTest.java index c34498a6b13c99fac39890fd15b4f0cc78c8a9bf..b959b6f5437a93faf2a7c18e6af03aadd29664e4 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LongArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/LongArrayOmniBlockTest.java @@ -61,9 +61,9 @@ public class LongArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Long[1], index -> new Random().nextLong())), - mockBlock(false, false, getVecAllocator(), fill(new Long[2], index -> new Random().nextLong())), - mockBlock(false, false, getVecAllocator(), fill(new Long[3], index -> new Random().nextLong())), + mockBlock(false, false, fill(new Long[1], index -> new Random().nextLong())), + mockBlock(false, false, fill(new Long[2], index -> new Random().nextLong())), + mockBlock(false, false, fill(new Long[3], index -> new Random().nextLong())), }; } @@ -71,7 +71,6 @@ public class LongArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); for (int i = 0; i < block.getPositionCount(); i++) { diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/RowOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/RowOmniBlockTest.java index c1e2c125657e77d6eb8e7ee03583ffb6001b0e10..d926afced9a44c41d3812747aa1c939829f9905b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/RowOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/RowOmniBlockTest.java @@ -46,7 +46,7 @@ public class RowOmniBlockTest { private Block rowBlock(Block block, Type dataType) { - return fromFieldBlocks(getVecAllocator(), block.getPositionCount(), Optional.empty(), new Block[]{block}, dataType, mockVec(ContainerVec.class, new Block[]{block}, getVecAllocator())); + return fromFieldBlocks(block.getPositionCount(), Optional.empty(), new Block[]{block}, dataType, mockVec(ContainerVec.class, new Block[]{block})); } @Override @@ -75,11 +75,11 @@ public class RowOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - rowBlock(requireNonNull(mockBlock(false, false, getVecAllocator(), fill(new String[1], index -> UUID.randomUUID().toString()))), + rowBlock(requireNonNull(mockBlock(false, false, fill(new String[1], index -> UUID.randomUUID().toString()))), RowType.from(ImmutableList.of(RowType.field(VARCHAR)))), - rowBlock(requireNonNull(mockBlock(false, false, getVecAllocator(), fill(new String[2], index -> UUID.randomUUID().toString()))), 
+ rowBlock(requireNonNull(mockBlock(false, false, fill(new String[2], index -> UUID.randomUUID().toString()))), RowType.from(ImmutableList.of(RowType.field(VARCHAR)))), - rowBlock(requireNonNull(mockBlock(false, false, getVecAllocator(), fill(new String[3], index -> UUID.randomUUID().toString()))), + rowBlock(requireNonNull(mockBlock(false, false, fill(new String[3], index -> UUID.randomUUID().toString()))), RowType.from(ImmutableList.of(RowType.field(VARCHAR)))) }; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ShortArrayOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ShortArrayOmniBlockTest.java index 3c81d05beec4a2a6d86dd38a5bdee502c416eaac..fc8934f44b9a1349132e47de3abde83e0735458a 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ShortArrayOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/ShortArrayOmniBlockTest.java @@ -51,9 +51,9 @@ public class ShortArrayOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new Short[1], index -> (short) new Random().nextInt(Short.MAX_VALUE))), - mockBlock(false, false, getVecAllocator(), fill(new Short[2], index -> (short) new Random().nextInt(Short.MAX_VALUE))), - mockBlock(false, false, getVecAllocator(), fill(new Short[3], index -> (short) new Random().nextInt(Short.MAX_VALUE))), + mockBlock(false, false, fill(new Short[1], index -> (short) new Random().nextInt(Short.MAX_VALUE))), + mockBlock(false, false, fill(new Short[2], index -> (short) new Random().nextInt(Short.MAX_VALUE))), + mockBlock(false, false, fill(new Short[3], index -> (short) new Random().nextInt(Short.MAX_VALUE))), }; } @@ -71,7 +71,6 @@ public class ShortArrayOmniBlockTest public void testFunctionCall(int index) { Block block = getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); block.retainedBytesForEachPart((offset, position) -> {}); block.writePositionTo(0, mock(BlockBuilder.class)); diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/VariableWidthOmniBlockTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/VariableWidthOmniBlockTest.java index 5b4ebef72d3b6bdc7a5edad1e73295c2f6e7ab1e..627881b220a8d061f8f498a01c4eacfcfc116c55 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/VariableWidthOmniBlockTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/block/VariableWidthOmniBlockTest.java @@ -49,9 +49,9 @@ public class VariableWidthOmniBlockTest protected Block[] blocksForTest() { return new Block[]{ - mockBlock(false, false, getVecAllocator(), fill(new String[1], index -> UUID.randomUUID().toString())), - mockBlock(false, false, getVecAllocator(), fill(new String[2], index -> UUID.randomUUID().toString())), - mockBlock(false, false, getVecAllocator(), fill(new String[3], index -> UUID.randomUUID().toString())) + mockBlock(false, false, fill(new String[1], index -> UUID.randomUUID().toString())), + mockBlock(false, false, fill(new String[2], index -> UUID.randomUUID().toString())), + mockBlock(false, false, fill(new String[3], index -> UUID.randomUUID().toString())) }; } @@ -59,7 +59,6 @@ public class VariableWidthOmniBlockTest public void testFunctionCall(int index) { Block block = 
getBlockForTest(index); - block.copyRegion(0, block.getPositionCount()); block.copyPositions(new int[block.getPositionCount()], 0, block.getPositionCount()); } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/mock/MockUtil.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/mock/MockUtil.java index e2b53f848a074a84111f79ac86376ba74dace314..eed047142e808de67955bbe0d4e24b298afde5bd 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/mock/MockUtil.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/mock/MockUtil.java @@ -41,9 +41,9 @@ import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -57,12 +57,14 @@ import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_SHORT; import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; import static nova.hetu.omniruntime.vector.VecEncoding.OMNI_VEC_ENCODING_FLAT; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.powermock.api.mockito.PowerMockito.doAnswer; import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.mockStatic; import static org.powermock.api.mockito.PowerMockito.when; import static org.powermock.api.mockito.PowerMockito.whenNew; @@ -112,40 +114,39 @@ public class MockUtil public static Page mockPage(BlockModel... 
blockModels) { Map> blocks = new HashMap<>(); - VecAllocator vecAllocator = mock(VecAllocator.class); for (int j = 0; j < blockModels.length; j++) { BlockModel blockModel = blockModels[j]; if (blockModel.rowBlock) { for (Object value : blockModel.values) { Page page = mockPage(block(blockModel.lazy, blockModel.dictionary, (Object[]) value)); - blocks.put(j, RowOmniBlock.fromFieldBlocks(vecAllocator, page.getPositionCount(), Optional.empty(), page.getBlocks(), null, mockVec(ContainerVec.class, page.getBlocks(), vecAllocator))); + blocks.put(j, RowOmniBlock.fromFieldBlocks(page.getPositionCount(), Optional.empty(), page.getBlocks(), null, mockVec(ContainerVec.class, page.getBlocks()))); } } Block block = null; Vec vec = null; if (blockModel.values instanceof Boolean[]) { - vec = mockVec(BooleanVec.class, blockModel.values, vecAllocator); + vec = mockVec(BooleanVec.class, blockModel.values); block = new ByteArrayOmniBlock(blockModel.values.length, (BooleanVec) vec); } if (blockModel.values instanceof Integer[]) { - vec = mockVec(IntVec.class, blockModel.values, vecAllocator); + vec = mockVec(IntVec.class, blockModel.values); block = new IntArrayOmniBlock(blockModel.values.length, (IntVec) vec); } if (blockModel.values instanceof Double[]) { - vec = mockVec(DoubleVec.class, blockModel.values, vecAllocator); + vec = mockVec(DoubleVec.class, blockModel.values); block = new DoubleArrayOmniBlock(blockModel.values.length, (DoubleVec) vec); } if (blockModel.values instanceof Short[]) { - vec = mockVec(ShortVec.class, blockModel.values, vecAllocator); + vec = mockVec(ShortVec.class, blockModel.values); block = new ShortArrayOmniBlock(blockModel.values.length, (ShortVec) vec); } if (blockModel.values instanceof Long[]) { - vec = mockVec(LongVec.class, blockModel.values, vecAllocator); + vec = mockVec(LongVec.class, blockModel.values); block = new LongArrayOmniBlock(blockModel.values.length, (LongVec) vec); } if (blockModel.values instanceof Long[][]) { - vec = mockVec(Decimal128Vec.class, blockModel.values, vecAllocator); + vec = mockVec(Decimal128Vec.class, blockModel.values); doAnswer(invocationOnMock -> { Long[] result = (Long[]) blockModel.values[(int) invocationOnMock.getArguments()[0]]; @@ -158,14 +159,13 @@ public class MockUtil block = new Int128ArrayOmniBlock(blockModel.values.length, (Decimal128Vec) vec); } if (blockModel.values instanceof String[]) { - vec = mockVec(VarcharVec.class, blockModel.values, vecAllocator); + vec = mockVec(VarcharVec.class, blockModel.values); int[] offsets = new int[blockModel.values.length + 1]; int startPosition = 0; for (int i = 0; i < blockModel.values.length; i++) { offsets[i + 1] = startPosition; startPosition += ((String[]) blockModel.values)[i].length(); } - when(((VarcharVec) vec).getRawValueOffset()).thenReturn(offsets); block = new VariableWidthOmniBlock(blockModel.values.length, (VarcharVec) vec); } @@ -180,7 +180,6 @@ public class MockUtil when(dictionaryVec.getSize()).thenReturn(blockModel.values.length); when(dictionaryVec.getDictionary()).thenReturn(vec); when(dictionaryVec.slice(anyInt(), anyInt())).thenReturn(dictionaryVec); - when(dictionaryVec.getAllocator()).thenReturn(vecAllocator); blocks.put(j, new DictionaryOmniBlock(0, blockModel.values.length, dictionaryVec, idIndex, false, randomDictionaryId())); } else { @@ -188,54 +187,46 @@ public class MockUtil } } } - return blocks.size() == 0 ? 
null : new Page(blocks.entrySet().stream().map(entry -> { - if (blockModels[entry.getKey()].lazy) { - LazyBlock lazyBlock = new LazyBlock(entry.getValue().getPositionCount(), instance -> {}); - lazyBlock.setBlock(entry.getValue()); - return new LazyOmniBlock(vecAllocator, lazyBlock, null); - } - else { - return entry.getValue(); - } - }).toArray(Block[]::new)); + return blocks.size() == 0 ? null : new Page(blocks.entrySet().stream().map(entry -> entry.getValue()).toArray(Block[]::new)); } - public static Block mockBlock(boolean lazy, boolean dictionary, VecAllocator vecAllocator, Object[] object) + public static Block mockBlock(boolean lazy, boolean dictionary, Object[] object) { Block block = null; Vec vec = null; DataType dataType = mock(DataType.class); if (object instanceof Boolean[]) { - vec = mockVec(BooleanVec.class, object, vecAllocator); + vec = mockVec(BooleanVec.class, object); when(dataType.getId()).thenReturn(OMNI_BOOLEAN); when(((BooleanVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) invocationOnMock.getArguments()[0]]); block = new ByteArrayOmniBlock(object.length, (BooleanVec) vec); } if (object instanceof Integer[]) { - vec = mockVec(IntVec.class, object, vecAllocator); + vec = mockVec(IntVec.class, object); when(dataType.getId()).thenReturn(OMNI_INT); when(((IntVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) invocationOnMock.getArguments()[0]]); block = new IntArrayOmniBlock(object.length, (IntVec) vec); } if (object instanceof Double[]) { - vec = mockVec(DoubleVec.class, object, vecAllocator); + vec = mockVec(DoubleVec.class, object); when(dataType.getId()).thenReturn(OMNI_DOUBLE); when(((DoubleVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) invocationOnMock.getArguments()[0]]); block = new DoubleArrayOmniBlock(object.length, (DoubleVec) vec); } if (object instanceof Short[]) { - vec = mockVec(ShortVec.class, object, vecAllocator); + vec = mockVec(ShortVec.class, object); + when(dataType.getId()).thenReturn(OMNI_SHORT); when(((ShortVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) invocationOnMock.getArguments()[0]]); block = new ShortArrayOmniBlock(object.length, (ShortVec) vec); } if (object instanceof Long[]) { - vec = mockVec(LongVec.class, object, vecAllocator); + vec = mockVec(LongVec.class, object); when(dataType.getId()).thenReturn(OMNI_LONG); when(((LongVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) invocationOnMock.getArguments()[0]]); block = new LongArrayOmniBlock(object.length, (LongVec) vec); } if (object instanceof Long[][]) { - vec = mockVec(Decimal128Vec.class, object, vecAllocator); + vec = mockVec(Decimal128Vec.class, object); when(dataType.getId()).thenReturn(OMNI_DECIMAL128); doAnswer(invocationOnMock -> { Long[] result = (Long[]) object[(int) invocationOnMock.getArguments()[0]]; @@ -248,7 +239,7 @@ public class MockUtil block = new Int128ArrayOmniBlock(object.length, (Decimal128Vec) vec); } if (object instanceof String[]) { - vec = mockVec(VarcharVec.class, object, vecAllocator); + vec = mockVec(VarcharVec.class, object); when(dataType.getId()).thenReturn(OMNI_VARCHAR); int[] offsets = new int[object.length + 1]; int startPosition = 0; @@ -256,13 +247,14 @@ public class MockUtil offsets[i + 1] = startPosition; startPosition += ((String[]) object)[i].length(); } - when(((VarcharVec) vec).getRawValueOffset()).thenReturn(offsets); - when(((VarcharVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> object[(int) 
invocationOnMock.getArguments()[0]]); + when(((VarcharVec) vec).get(anyInt())).thenAnswer(invocationOnMock -> + String.valueOf(object[(Integer) invocationOnMock.getArguments()[0]]).getBytes(Charset.defaultCharset())); block = new VariableWidthOmniBlock(object.length, (VarcharVec) vec); } if (block != null) { if (dictionary) { DictionaryVec dictionaryVec = mock(DictionaryVec.class); + mockStatic(DictionaryOmniBlock.class); int[] idIndex = new int[object.length]; for (int i = 0; i < object.length; i++) { idIndex[i] = i; @@ -270,8 +262,9 @@ public class MockUtil when(dictionaryVec.getIds()).thenReturn(idIndex); when(dictionaryVec.getSize()).thenReturn(object.length); when(dictionaryVec.getDictionary()).thenReturn(vec); + when(DictionaryOmniBlock.expandDictionary(dictionaryVec)).thenReturn(block); + when(DictionaryOmniBlock.getIds(object.length)).thenReturn(idIndex); when(dictionaryVec.slice(anyInt(), anyInt())).thenReturn(dictionaryVec); - when(dictionaryVec.getAllocator()).thenReturn(vecAllocator); when(vec.getEncoding()).thenReturn(OMNI_VEC_ENCODING_FLAT); when(vec.getType()).thenReturn(dataType); block = new DictionaryOmniBlock>(dictionaryVec, false, randomDictionaryId()); @@ -279,7 +272,7 @@ public class MockUtil if (lazy) { LazyBlock lazyBlock = new LazyBlock>(block.getPositionCount(), instance -> {}); lazyBlock.setBlock(block); - return new LazyOmniBlock>(vecAllocator, lazyBlock, null); + return new LazyOmniBlock>(lazyBlock, null); } return block; } @@ -299,12 +292,11 @@ public class MockUtil return instance; } - public static T mockVec(Class vecClass, Object[] values, VecAllocator vecAllocator) + public static T mockVec(Class vecClass, Object[] values) { T vec = mock(vecClass); when(vec.getSize()).thenReturn(values.length); when(vec.slice(anyInt(), anyInt())).thenReturn(vec); - when(vec.getAllocator()).thenReturn(vecAllocator); when(vec.copyPositions(any(), anyInt(), anyInt())).thenReturn(vec); when(vec.slice(anyInt(), anyInt())).thenReturn(vec); return vec; @@ -319,7 +311,6 @@ public class MockUtil return 0; }).when(omniOperator).addInput(any()); doAnswer(invocation -> innerVec.listIterator()).when(omniOperator).getOutput(); - when(omniOperator.getVecAllocator()).thenReturn(mock(VecAllocator.class)); return omniOperator; } } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AbstractOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AbstractOperatorTest.java index 665aa34fe03b46c2da7690081e616a7e04c50ff7..12ff04b4fd0921f7ca365b9f89b73f548342b42e 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AbstractOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AbstractOperatorTest.java @@ -39,9 +39,9 @@ import io.prestosql.spi.snapshot.MarkerPage; import io.prestosql.spi.type.TimeZoneKey; import io.prestosql.sql.SqlPath; import io.prestosql.transaction.TransactionId; +import nova.hetu.olk.tool.OmniPage; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.FixedWidthVec; -import nova.hetu.omniruntime.vector.LazyVec; import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; import org.mockito.Mock; @@ -76,7 +76,6 @@ import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.when; @SuppressStaticInitializationFor({ - "nova.hetu.omniruntime.vector.VecAllocator", "nova.hetu.omniruntime.vector.Vec", 
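// Descriptive note (illustrative, not part of the patch): @SuppressStaticInitializationFor makes
// PowerMock skip the static initializers of the listed classes, which for these OmniRuntime
// vector/operator classes presumably guard native-library loading that is unavailable in a plain
// unit-test JVM. A minimal sketch of the pattern used by these tests (class names are placeholders):
//
//   @SuppressStaticInitializationFor("com.example.NativeBackedVec")
//   @PowerMockIgnore("javax.management.*")
//   public class NativeBackedVecTest extends PowerMockTestCase {
//       @Test
//       public void interactsWithMockOnly() {
//           NativeBackedVec vec = PowerMockito.mock(NativeBackedVec.class);
//           // the real static initializer of NativeBackedVec never runs here
//       }
//   }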
"nova.hetu.omniruntime.constants.Constant", "nova.hetu.omniruntime.operator.OmniOperatorFactory" @@ -180,10 +179,10 @@ public class AbstractOperatorTest { VecBatch vecBatch = mockNewVecWithAnyArguments(VecBatch.class); when(vecBatch.getVectors()).thenReturn(new Vec[0]); - mockNewVecWithAnyArguments(LazyVec.class); mockNewVecWithAnyArguments(Vec.class); mockNewVecWithAnyArguments(BooleanVec.class); mockNewVecWithAnyArguments(FixedWidthVec.class); + mockNewVecWithAnyArguments(OmniPage.class); } protected OperatorFactory createOperatorFactory() diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AggregationOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AggregationOmniOperatorTest.java index 4a316af0438f826770cec9388fc77ef9d890a5af..6354640ed075c0a09ecdb9142756f7fe58c320e3 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AggregationOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/AggregationOmniOperatorTest.java @@ -44,7 +44,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; import static org.powermock.api.mockito.PowerMockito.doReturn; @PrepareForTest({ @@ -75,7 +74,7 @@ public class AggregationOmniOperatorTest protected OperatorFactory createOperatorFactory() { OmniAggregationOperatorFactory factory = mockNewVecWithAnyArguments(OmniAggregationOperatorFactory.class); - doReturn(omniOperator).when(factory).createOperator(any()); + doReturn(omniOperator).when(factory).createOperator(); return new AggregationOmniOperatorFactory(operatorId, planNodeId, sourceTypes, aggregatorTypes, aggregationInputChannels, maskChannelList, aggregationOutputTypes, step); } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/BuildOffHeapOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/BuildOffHeapOmniOperatorTest.java index 8937618d530dd57ba549adbcd1c849e1504bf640..cdf3e0ed49710307af701e4627e95c14abead162 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/BuildOffHeapOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/BuildOffHeapOmniOperatorTest.java @@ -21,6 +21,10 @@ import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.operator.BuildOffHeapOmniOperator.BuildOffHeapOmniOperatorFactory; +import nova.hetu.olk.tool.OmniPage; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; import org.testng.annotations.Test; import java.util.Collections; @@ -33,6 +37,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +@RunWith(PowerMockRunner.class) +@PrepareForTest(OmniPage.class) public class BuildOffHeapOmniOperatorTest extends AbstractOperatorTest { @@ -45,6 +51,12 @@ public class BuildOffHeapOmniOperatorTest return new BuildOffHeapOmniOperatorFactory(operatorId, planNodeId, inputTypes); } + @Override + protected void setUpMock() + { + super.setUpMock(); + } + @Test(dataProvider = "pageProvider") public void 
testProcess(int i) { diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DistinctLimitOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DistinctLimitOmniOperatorTest.java index 6289bf290b562c5a853a461ae2b4f9f74d6e9ae2..4dfc9643be1aefb915762210618a894e29d80d20 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DistinctLimitOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DistinctLimitOmniOperatorTest.java @@ -19,7 +19,6 @@ import io.prestosql.operator.Operator; import io.prestosql.operator.OperatorFactory; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; -import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.block.LazyOmniBlock; @@ -95,7 +94,7 @@ public class DistinctLimitOmniOperatorTest return new Page[]{ mockPage(), PageBuilder.withMaxPageSize(1, asList()).build(), - new Page(new LazyBlock(10, block -> {}), new LazyBlock(10, block -> {})) + PageBuilder.withMaxPageSize(1, asList()).build() }; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperatorTest.java index b140633a4869b2f54de9e4ea496711b754752d53..e5dac51f01762d9db8662be86cf16be49c686ed7 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/DynamicFilterSourceOmniOperatorTest.java @@ -71,7 +71,7 @@ public class DynamicFilterSourceOmniOperatorTest protected Operator createOperator(Operator originalOperator) { return new DynamicFilterSourceOmniOperator(originalOperator.getOperatorContext(), mapConsumer, channels, - planNodeId, maxFilterPositionsCount, maxFilterSize, null); + planNodeId, maxFilterPositionsCount, maxFilterSize); } @Test(dataProvider = "pageProvider") diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperatorTest.java index 2ce220cacd1e58bb01b9b04566c8ea36bb32ddf9..86c7737d5600755dec0fc18ef8aaabfd5825b166 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/EnforceSingleRowOmniOperatorTest.java @@ -23,7 +23,7 @@ import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; import nova.hetu.olk.block.ByteArrayOmniBlock; import nova.hetu.olk.operator.EnforceSingleRowOmniOperator.EnforceSingleRowOmniOperatorFactory; -import nova.hetu.omniruntime.vector.VecAllocator; +import nova.hetu.olk.tool.OmniPage; import org.powermock.core.classloader.annotations.PrepareForTest; import org.testng.annotations.Test; @@ -39,11 +39,11 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; @PrepareForTest({ DistinctLimitOmniOperator.class, - ByteArrayOmniBlock.class + 
ByteArrayOmniBlock.class, + OmniPage.class }) public class EnforceSingleRowOmniOperatorTest extends AbstractOperatorTest @@ -68,7 +68,7 @@ public class EnforceSingleRowOmniOperatorTest @Override protected Operator createOperator(Operator originalOperator) { - return new EnforceSingleRowOmniOperator(originalOperator.getOperatorContext(), mock(VecAllocator.class)); + return new EnforceSingleRowOmniOperator(originalOperator.getOperatorContext()); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java index b0480a3f1860dde47445cd20f42633d0623c78f2..df3d534f09bcb5739c727d5a8d39825c65fd14d2 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/HashAggregationOmniOperatorTest.java @@ -43,7 +43,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -75,7 +74,7 @@ public class HashAggregationOmniOperatorTest super.setUpMock(); omniHashAggregationOperatorFactory = mockNewVecWithAnyArguments(OmniHashAggregationOperatorFactory.class); omniOperator = mockOmniOperator(); - when(omniHashAggregationOperatorFactory.createOperator(any())).thenReturn(omniOperator); + when(omniHashAggregationOperatorFactory.createOperator()).thenReturn(omniOperator); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LimitOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LimitOmniOperatorTest.java index 7ae33d15f7a5c7a00b838d2b487ad6bfaf70a946..de00b14c0c72b982b60dc5674a3b5c0046e58497 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LimitOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LimitOmniOperatorTest.java @@ -25,7 +25,6 @@ import io.prestosql.spi.type.Type; import nova.hetu.olk.block.LazyOmniBlock; import nova.hetu.olk.operator.LimitOmniOperator.LimitOmniOperatorFactory; import nova.hetu.olk.tool.OperatorUtils; -import nova.hetu.omniruntime.operator.limit.OmniLimitOperatorFactory; import org.powermock.core.classloader.annotations.PrepareForTest; import org.testng.annotations.Test; @@ -35,8 +34,6 @@ import java.util.List; import java.util.Random; import java.util.UUID; -import static nova.hetu.olk.mock.MockUtil.mockNewVecWithAnyArguments; -import static nova.hetu.olk.mock.MockUtil.mockOmniOperator; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -58,7 +55,6 @@ public class LimitOmniOperatorTest protected void setUpMock() { super.setUpMock(); - mockNewVecWithAnyArguments(OmniLimitOperatorFactory.class); } @Override @@ -70,7 +66,7 @@ public class LimitOmniOperatorTest @Override protected Operator createOperator(Operator originalOperator) { - return new LimitOmniOperator(originalOperator.getOperatorContext(), mockOmniOperator(), limit); + return new 
LimitOmniOperator(originalOperator.getOperatorContext(), limit); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperatorTest.java index 55445612227664226eece19c0891d59230d563ea..5ddb183b701f8285a91e817de9a6cbdedd64e17a 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LocalMergeSourceOmniOperatorTest.java @@ -23,7 +23,6 @@ import io.prestosql.operator.exchange.LocalExchange; import io.prestosql.operator.exchange.LocalExchange.LocalExchangeFactory; import io.prestosql.operator.exchange.LocalExchangeSource; import io.prestosql.spi.Page; -import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.Type; @@ -43,8 +42,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import static io.prestosql.operator.Operator.NOT_BLOCKED; import static java.util.Arrays.asList; +import static nova.hetu.olk.mock.MockUtil.block; +import static nova.hetu.olk.mock.MockUtil.fill; import static nova.hetu.olk.mock.MockUtil.mockNewVecWithAnyArguments; import static nova.hetu.olk.mock.MockUtil.mockOmniOperator; +import static nova.hetu.olk.mock.MockUtil.mockPage; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; @@ -108,10 +110,10 @@ public class LocalMergeSourceOmniOperatorTest when(localExchange.getBufferCount()).thenReturn(2); when(localExchange.getNextSource()).thenReturn(firstLocalExchangeSource, secondLocalExchangeSource); when(firstLocalExchangeSource.getPages()).thenAnswer((invocation -> { - return firstSourceFinish.get() ? asList(new Page(new LazyBlock(10, block -> {}))) : null; + return firstSourceFinish.get() ? asList(mockPage(block(false, false, fill(new Integer[3], index -> new Random().nextInt())))) : null; })); when(secondLocalExchangeSource.getPages()).thenAnswer((invocation -> { - return secondSourceFinish.get() ? asList(new Page(new LazyBlock(10, block -> {}))) : null; + return secondSourceFinish.get() ? 
asList(mockPage(block(false, false, fill(new Integer[3], index -> new Random().nextInt())))) : null; })); when(firstLocalExchangeSource.waitForReading()).thenReturn(new AbstractFuture() { diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/OrderByOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/OrderByOperatorTest.java index 550db3855de5b1211c5e693eaa73bdcfe78d0c8a..74a08719b101d1893451796eb5d90abec94923af 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/OrderByOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/OrderByOperatorTest.java @@ -23,7 +23,6 @@ import io.prestosql.spi.type.Type; import nova.hetu.olk.tool.OperatorUtils; import nova.hetu.omniruntime.operator.OmniOperator; import nova.hetu.omniruntime.operator.sort.OmniSortOperatorFactory; -import nova.hetu.omniruntime.vector.VecAllocator; import org.powermock.core.classloader.annotations.PrepareForTest; import org.testng.annotations.Test; @@ -34,7 +33,6 @@ import java.util.UUID; import static nova.hetu.olk.mock.MockUtil.mockOmniOperator; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -69,7 +67,7 @@ public class OrderByOperatorTest { OmniSortOperatorFactory omniSortOperatorFactory = mock(OmniSortOperatorFactory.class); OmniOperator omniOperator = mockOmniOperator(); - when(omniSortOperatorFactory.createOperator(any(VecAllocator.class))).thenReturn(omniOperator); + when(omniSortOperatorFactory.createOperator()).thenReturn(omniOperator); return omniSortOperatorFactory; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/AbstractOperatorBenchmarkContext.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/AbstractOperatorBenchmarkContext.java index 12acf697108eca2ce422afa3a2cef00c4f334844..fddfdf7f7a4a120c347dd19732b3c988cb1ef17f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/AbstractOperatorBenchmarkContext.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/AbstractOperatorBenchmarkContext.java @@ -27,8 +27,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.testing.TestingSnapshotUtils; import io.prestosql.testing.TestingTaskContext; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.TearDown; @@ -284,13 +282,10 @@ public abstract class AbstractOperatorBenchmarkContext abstract static class AbstractOmniOperatorBenchmarkContext extends AbstractOperatorBenchmarkContext { - private VecAllocator taskLevelAllocator; - @Override protected TaskContext createTaskContext() { TaskContext taskContext = super.createTaskContext(); - taskLevelAllocator = VecAllocatorHelper.createTaskLevelAllocator(taskContext); return taskContext; } @@ -301,7 +296,7 @@ public abstract class AbstractOperatorBenchmarkContext for (Page page : pages) { slicedPages.add(page.getRegion(0, page.getPositionCount())); } - return transferToOffHeapPages(taskLevelAllocator, slicedPages); + return transferToOffHeapPages(slicedPages); } } diff --git 
a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkEnforceSingleRowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkEnforceSingleRowOmniOperator.java index 9d60222bca232f4cf8bd5d17c207330b80bb8e7b..10c22f2f72f70262704e5de4325c7da293c0b97b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkEnforceSingleRowOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkEnforceSingleRowOmniOperator.java @@ -28,7 +28,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.testing.TestingSnapshotUtils; import nova.hetu.olk.operator.EnforceSingleRowOmniOperator; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -153,7 +152,7 @@ public class BenchmarkEnforceSingleRowOmniOperator { List pages = rowPagesBuilder(INPUT_TYPES.get(testGroup)).addSequencePage(1, 1) .addSequencePage(2, 1).build(); - return transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + return transferToOffHeapPages(pages); } public static void main(String[] args) throws RunnerException diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java index f29c85116e570f4506ce09594fc4956a6f728f1b..a2fc57dbbc2afe5998f73d59fda74d03a935a25f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java @@ -39,8 +39,6 @@ import io.prestosql.spi.type.VarcharType; import io.prestosql.type.TypeUtils; import nova.hetu.olk.operator.HashBuilderOmniOperator.HashBuilderOmniOperatorFactory; import nova.hetu.olk.operator.LookupJoinOmniOperators; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -158,8 +156,6 @@ public class BenchmarkHashJoinOmniOperators protected List buildOutputChannels; protected List buildJoinChannels; protected OptionalInt buildHashChannel; - - private VecAllocator buildVecAllocator; protected JoinBridgeManager lookupSourceFactoryManager; @Override @@ -196,7 +192,7 @@ public class BenchmarkHashJoinOmniOperators for (Page page : pages) { slicedPages.add(page.getRegion(0, page.getPositionCount())); } - return transferToOffHeapPages(buildVecAllocator, slicedPages); + return transferToOffHeapPages(slicedPages); } @Override @@ -256,7 +252,6 @@ public class BenchmarkHashJoinOmniOperators protected TaskContext createTaskContext() { TaskContext testingTaskContext = createTaskContextBySizeInGigaByte(4); - buildVecAllocator = VecAllocatorHelper.createTaskLevelAllocator(testingTaskContext); return testingTaskContext; } @@ -399,8 +394,6 @@ public class BenchmarkHashJoinOmniOperators protected List probeOutputChannels; protected List probeJoinChannels; protected OptionalInt probeHashChannel; - - private VecAllocator 
probeVecAllocator; private DriverContext buildDriverContext; private Operator buildOperator; @@ -449,7 +442,7 @@ public class BenchmarkHashJoinOmniOperators for (Page page : pages) { slicedPages.add(page.getRegion(0, page.getPositionCount())); } - return transferToOffHeapPages(probeVecAllocator, slicedPages); + return transferToOffHeapPages(slicedPages); } public List getProbeTypes() @@ -549,7 +542,6 @@ public class BenchmarkHashJoinOmniOperators protected TaskContext createTaskContext() { TaskContext testingTaskContext = createTaskContextBySizeInGigaByte(4); - probeVecAllocator = VecAllocatorHelper.createTaskLevelAllocator(testingTaskContext); return testingTaskContext; } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkMergeOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkMergeOmniOperator.java index 2b125ff1236435f77ad70994160141ca4a878d34..52f54819aa2f7f8d56aa320005ed23f9462cd8c6 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkMergeOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkMergeOmniOperator.java @@ -46,8 +46,6 @@ import io.prestosql.testing.TestingTaskContext; import nova.hetu.olk.operator.MergeOmniOperator; import nova.hetu.olk.operator.MergeOmniOperator.MergeOmniOperatorFactory; import nova.hetu.olk.tool.BlockUtils; -import nova.hetu.olk.tool.VecAllocatorHelper; -import nova.hetu.omniruntime.vector.VecAllocator; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -130,7 +128,6 @@ public class BenchmarkMergeOmniOperator private MergeOmniOperatorFactory operatorFactory; private TaskContext testingTaskContext; - private VecAllocator taskLevelAllocator; private List pageTemplate; @Param({"group1", "group2", "group3", "group4", "group5", "group6", "group7"}) @@ -208,7 +205,6 @@ public class BenchmarkMergeOmniOperator { TaskContext taskContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) .setQueryMaxMemory(new DataSize(2, GIGABYTE)).setTaskStateMachine(new TaskStateMachine(new TaskId("query", 1, 1), executor)).build(); - taskLevelAllocator = VecAllocatorHelper.createTaskLevelAllocator(taskContext); return taskContext; } @@ -218,7 +214,7 @@ public class BenchmarkMergeOmniOperator for (Page page : pageTemplate) { slicedPages.add(page.getRegion(0, page.getPositionCount())); } - return transferToOffHeapPages(taskLevelAllocator, slicedPages); + return transferToOffHeapPages(slicedPages); } } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestBlockUtils.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestBlockUtils.java index 2dc35e8132266e562926c0987be208891f4e9a65..0628975f3a9267b06f7e36d544f481aabd8dcaac 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestBlockUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestBlockUtils.java @@ -23,7 +23,6 @@ import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.junit.runner.RunWith; import 
org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -31,41 +30,27 @@ import org.powermock.core.classloader.annotations.SuppressStaticInitializationFo import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.testng.PowerMockTestCase; import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; -import static nova.hetu.olk.tool.BlockUtils.compactVec; -import static org.mockito.Matchers.anyInt; import static org.powermock.api.mockito.PowerMockito.mock; import static org.powermock.api.mockito.PowerMockito.when; import static org.powermock.api.mockito.PowerMockito.whenNew; -import static org.testng.Assert.assertEquals; @RunWith(PowerMockRunner.class) -@PrepareForTest({VecAllocator.class, - Vec.class, +@PrepareForTest({Vec.class, BlockUtils.class }) -@SuppressStaticInitializationFor({"nova.hetu.omniruntime.vector.VecAllocator", - "nova.hetu.omniruntime.vector.Vec" -}) +@SuppressStaticInitializationFor("nova.hetu.omniruntime.vector.Vec") @PowerMockIgnore("javax.management.*") public class TestBlockUtils extends PowerMockTestCase { BooleanVec booleanVec; - BooleanVec booleanVecRegion; IntVec intVec; - IntVec intVecRegion; ShortVec shortVec; - ShortVec shortVecRegion; LongVec longVec; - LongVec longVecRegion; DoubleVec doubleVec; - DoubleVec doubleVecRegion; VarcharVec varcharVec; - VarcharVec varcharVecRegion; Decimal128Vec decimal128Vec; - Decimal128Vec decimal128VecRegion; @BeforeMethod public void setUp() throws Exception @@ -73,80 +58,34 @@ public class TestBlockUtils mockSupports(); } - @Test - public void testVecCompact() - { - assertEquals(booleanVec, compactVec(booleanVec, 0, 4)); - assertEquals(booleanVecRegion, compactVec(booleanVec, 1, 2)); - - assertEquals(intVec, compactVec(intVec, 0, 4)); - assertEquals(intVecRegion, compactVec(intVec, 1, 2)); - - assertEquals(shortVec, compactVec(shortVec, 0, 4)); - assertEquals(shortVecRegion, compactVec(shortVec, 1, 2)); - - assertEquals(longVec, compactVec(longVec, 0, 4)); - assertEquals(longVecRegion, compactVec(longVec, 1, 2)); - - assertEquals(doubleVec, compactVec(doubleVec, 0, 4)); - assertEquals(doubleVecRegion, compactVec(doubleVec, 1, 2)); - - assertEquals(varcharVec, compactVec(varcharVec, 0, 4)); - assertEquals(varcharVecRegion, compactVec(varcharVec, 1, 2)); - - assertEquals(decimal128Vec, compactVec(decimal128Vec, 0, 4)); - assertEquals(decimal128VecRegion, compactVec(decimal128Vec, 1, 2)); - } - private void mockSupports() throws Exception { booleanVec = mock(BooleanVec.class); whenNew(BooleanVec.class).withAnyArguments().thenReturn(booleanVec); when(booleanVec.getSize()).thenReturn(4); - when(booleanVec.getOffset()).thenReturn(0); - booleanVecRegion = mock(BooleanVec.class); - when(booleanVec.copyRegion(anyInt(), anyInt())).thenReturn(booleanVecRegion); intVec = mock(IntVec.class); whenNew(IntVec.class).withAnyArguments().thenReturn(intVec); when(intVec.getSize()).thenReturn(4); - when(intVec.getOffset()).thenReturn(0); - intVecRegion = mock(IntVec.class); - when(intVec.copyRegion(anyInt(), anyInt())).thenReturn(intVecRegion); shortVec = mock(ShortVec.class); whenNew(ShortVec.class).withAnyArguments().thenReturn(shortVec); when(shortVec.getSize()).thenReturn(4); - when(shortVec.getOffset()).thenReturn(0); - shortVecRegion = mock(ShortVec.class); - when(shortVec.copyRegion(anyInt(), anyInt())).thenReturn(shortVecRegion); longVec = mock(LongVec.class); whenNew(LongVec.class).withAnyArguments().thenReturn(longVec); 
when(longVec.getSize()).thenReturn(4); - when(longVec.getOffset()).thenReturn(0); - longVecRegion = mock(LongVec.class); - when(longVec.copyRegion(anyInt(), anyInt())).thenReturn(longVecRegion); doubleVec = mock(DoubleVec.class); whenNew(DoubleVec.class).withAnyArguments().thenReturn(doubleVec); when(doubleVec.getSize()).thenReturn(4); - when(doubleVec.getOffset()).thenReturn(0); - doubleVecRegion = mock(DoubleVec.class); - when(doubleVec.copyRegion(anyInt(), anyInt())).thenReturn(doubleVecRegion); varcharVec = mock(VarcharVec.class); whenNew(VarcharVec.class).withAnyArguments().thenReturn(varcharVec); when(varcharVec.getSize()).thenReturn(4); - when(varcharVec.getOffset()).thenReturn(0); - varcharVecRegion = mock(VarcharVec.class); - when(varcharVec.copyRegion(anyInt(), anyInt())).thenReturn(varcharVecRegion); decimal128Vec = mock(Decimal128Vec.class); whenNew(Decimal128Vec.class).withAnyArguments().thenReturn(decimal128Vec); when(decimal128Vec.getSize()).thenReturn(4); - when(decimal128Vec.getOffset()).thenReturn(0); - decimal128VecRegion = mock(Decimal128Vec.class); - when(decimal128Vec.copyRegion(anyInt(), anyInt())).thenReturn(decimal128VecRegion); } } diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestOperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestOperatorUtils.java index 663f189b10fa5d5f3fbe2abc8b3e21903e391335..63f7d76e8487f66c54524423da53517249c90544 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestOperatorUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/tool/TestOperatorUtils.java @@ -58,9 +58,7 @@ import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; -import nova.hetu.omniruntime.vector.VecAllocator; import org.junit.runner.RunWith; -import org.powermock.api.support.membermodification.MemberModifier; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; @@ -85,6 +83,7 @@ import static io.prestosql.spi.type.SmallintType.SMALLINT; import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static io.prestosql.spi.type.VarbinaryType.VARBINARY; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static nova.hetu.olk.mock.MockUtil.mockNewVecWithAnyArguments; import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; import static nova.hetu.olk.tool.OperatorUtils.transferToOnHeapPage; import static nova.hetu.olk.tool.OperatorUtils.transferToOnHeapPages; @@ -96,11 +95,12 @@ import static org.powermock.api.mockito.PowerMockito.whenNew; import static org.testng.Assert.assertEquals; @RunWith(PowerMockRunner.class) -@PrepareForTest({VecAllocator.class, +@PrepareForTest({ Vec.class, - OperatorUtils.class + OperatorUtils.class, + OmniPage.class }) -@SuppressStaticInitializationFor({"nova.hetu.omniruntime.vector.VecAllocator", +@SuppressStaticInitializationFor({ "nova.hetu.omniruntime.vector.Vec", "nova.hetu.olk.block.RowOmniBlock" }) @@ -125,6 +125,7 @@ public class TestOperatorUtils @BeforeMethod public void setUp() throws Exception { + mockNewVecWithAnyArguments(OmniPage.class); mockSupports(); } @@ -133,10 +134,11 @@ public class TestOperatorUtils { List pages = buildPages(types, false, 100); // 
transfer on-feap page to off-heap - List offHeapPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + List offHeapPages = transferToOffHeapPages(pages); // transfer off-heap page to on-heap - List onHeapPages = transferToOnHeapPages(offHeapPages); - freeNativeMemory(offHeapPages); + List omniPages = buildOmniPages(); + List onHeapPages = transferToOnHeapPages(omniPages); + freeNativeMemory(omniPages); } @Test @@ -144,10 +146,11 @@ public class TestOperatorUtils { List pages = buildPages(types, true, 100); // transfer on-feap page to off-heap - List offHeapPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + List offHeapPages = transferToOffHeapPages(pages); // transfer off-heap page to on-heap - List onHeapPages = transferToOnHeapPages(offHeapPages); - freeNativeMemory(offHeapPages); + List omniPages = buildOmniPages(); + List onHeapPages = transferToOnHeapPages(omniPages); + freeNativeMemory(omniPages); } @Test @@ -156,11 +159,11 @@ public class TestOperatorUtils Type type = BIGINT; Page page = new Page(buildRowBlockByBuilder(type)); // transfer on-heap page to off-heap - Page offHeapPage = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, page, - ImmutableList.of(RowType.from(ImmutableList.of(RowType.field(type))))); + Page offHeapPage = transferToOffHeapPages(page, ImmutableList.of(RowType.from(ImmutableList.of(RowType.field(type))))); // transfer off-heap page to on-heap - Page onHeapPage = transferToOnHeapPage(offHeapPage); - BlockUtils.freePage(offHeapPage); + Page omniPage = buildOmniPage(); + Page onHeapPage = transferToOnHeapPage(omniPage); + BlockUtils.freePage(omniPage); } @Test @@ -179,18 +182,19 @@ public class TestOperatorUtils runLengthEncodedPages.add(new Page(blocks.toArray(new Block[blocks.size()]))); } // transfer on-heap page to off-heap - List offHeapPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, runLengthEncodedPages); + List offHeapPages = transferToOffHeapPages(runLengthEncodedPages); // transfer off-heap page to on-heap - List onHeapPages = transferToOnHeapPages(offHeapPages); - freeNativeMemory(offHeapPages); + List omniPages = buildOmniPages(); + List onHeapPages = transferToOnHeapPages(omniPages); + freeNativeMemory(omniPages); } @Test public void testBlockTypeTransfer() { Page page = buildPages(new ImmutableList.Builder().add(DOUBLE).build(), false, 1).get(0); - Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, page.getBlock(0), "LongArrayBlock", 1, DOUBLE); - OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, new RunLengthEncodedBlock(page.getBlock(0), 1), "RunLengthEncodedBlock", 1, DOUBLE); + Block block = OperatorUtils.buildOffHeapBlock(page.getBlock(0), "LongArrayBlock", 1, DOUBLE); + OperatorUtils.buildOffHeapBlock(new RunLengthEncodedBlock(page.getBlock(0), 1), "RunLengthEncodedBlock", 1, DOUBLE); } @Test @@ -236,7 +240,7 @@ public class TestOperatorUtils vecs.add(varcharVec); vecs.add(decimal128Vec); vecs.add(containerVec); - assertEquals(vecs, OperatorUtils.createBlankVectors(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, dataTypes, 1)); + assertEquals(vecs, OperatorUtils.createBlankVectors(dataTypes, 1)); } @Test @@ -266,6 +270,21 @@ public class TestOperatorUtils return pages; } + private List buildOmniPages() + { + List pages = new ArrayList<>(); + Page page = buildOmniPage(); + pages.add(page); + return pages; + } + + private Page buildOmniPage() + { + IntArrayOmniBlock intArrayOmniBlock = new IntArrayOmniBlock(3, intVec); 
+ Page page = new Page(intArrayOmniBlock); + return page; + } + private Block buildRowBlockByBuilder(Type type) { BlockBuilder rowBlockBuilder = type.createBlockBuilder(null, 4); @@ -284,17 +303,12 @@ public class TestOperatorUtils private void mockSupports() throws Exception { - //mock GLOBAL_VECTOR_ALLOCATOR - VecAllocator vecAllocator = mock(VecAllocator.class); - MemberModifier.field(VecAllocator.class, "GLOBAL_VECTOR_ALLOCATOR").set(VecAllocator.class, vecAllocator); - ByteArrayOmniBlock byteArrayOmniBlock = mock(ByteArrayOmniBlock.class); when(byteArrayOmniBlock.isExtensionBlock()).thenReturn(true); when(byteArrayOmniBlock.getPositionCount()).thenReturn(1); whenNew(ByteArrayOmniBlock.class).withAnyArguments().thenReturn(byteArrayOmniBlock); booleanVec = mock(BooleanVec.class, RETURNS_DEEP_STUBS); when(booleanVec.getValuesBuf().getBytes(anyInt(), anyInt())).thenReturn(new byte[]{1}); - when(booleanVec.getOffset()).thenReturn(0); when(booleanVec.getValuesNulls(anyInt(), anyInt())).thenReturn(new boolean[]{true}); whenNew(BooleanVec.class).withAnyArguments().thenReturn(booleanVec); when(byteArrayOmniBlock.getValues()).thenReturn(booleanVec); @@ -354,11 +368,9 @@ public class TestOperatorUtils when(variableWidthOmniBlock.getPositionCount()).thenReturn(1); whenNew(VariableWidthOmniBlock.class).withAnyArguments().thenReturn(variableWidthOmniBlock); varcharVec = mock(VarcharVec.class); - when(varcharVec.hasNullValue()).thenReturn(false); + when(varcharVec.hasNull()).thenReturn(false); when(varcharVec.getValuesNulls(anyInt(), anyInt())).thenReturn(new boolean[]{true}); - when(varcharVec.getValueOffset(anyInt())).thenAnswer(n -> n.getArguments()[0]); - when(varcharVec.getValueOffset(anyInt(), anyInt())).thenReturn(new int[]{0, 1}); - when(varcharVec.getData(anyInt(), anyInt())).thenReturn(new byte[]{1}); + when(varcharVec.get(anyInt(), anyInt())).thenReturn(new byte[]{1}); whenNew(VarcharVec.class).withAnyArguments().thenReturn(varcharVec); when(variableWidthOmniBlock.getValues()).thenReturn(varcharVec); @@ -368,7 +380,7 @@ public class TestOperatorUtils when(dictionaryOmniBlock.getDictionary()).thenReturn(byteArrayOmniBlock); whenNew(DictionaryOmniBlock.class).withAnyArguments().thenReturn(dictionaryOmniBlock); dictionaryVec = mock(DictionaryVec.class); - when(dictionaryVec.getIds(anyInt())).thenReturn(new int[]{1}); + when(dictionaryVec.getIds()).thenReturn(new int[]{1}); when(dictionaryVec.getValuesNulls(anyInt(), anyInt())).thenReturn(new boolean[]{true}); whenNew(DictionaryVec.class).withAnyArguments().thenReturn(dictionaryVec); when(dictionaryOmniBlock.getValues()).thenReturn(dictionaryVec); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt similarity index 97% rename from omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt index 92d57e99819f7b21e42a01d242b044e8f667fe12..86d401d8384bb36b65aa75b6c10dde7abb74f8ba 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt @@ -7,7 +7,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON) cmake_minimum_required(VERSION 3.10) # configure cmake -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(root_directory ${PROJECT_BINARY_DIR}) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/build.sh 
b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/build.sh similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/build.sh rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/build.sh diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt similarity index 96% rename from omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt index 4e3c3e2160cc8415575ed5c6745f29ac60fc298b..27a927fdb7c0fceae786683ccced1f396af59d1a 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt @@ -38,8 +38,7 @@ target_include_directories(${PROJ_TARGET} PUBLIC /opt/lib/include) target_link_libraries (${PROJ_TARGET} PUBLIC protobuf.a z - boostkit-omniop-runtime-1.1.0-aarch64 - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.3.0-aarch64 ock_shuffle gcov ) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/common/common.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/common.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/common/common.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/common.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/common/debug.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/debug.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/common/debug.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/debug.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp similarity index 75% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp index 456519e9a8ee7edac294289f84273244f50c9d62..21e482c8d2f2b3457d1167c83c3aaa7e7fc09da1 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp @@ -10,6 +10,7 @@ #include "OckShuffleJniReader.h" using namespace omniruntime::vec; +using namespace omniruntime::type; using namespace ock::dopspark; static std::mutex gInitLock; @@ -20,11 +21,16 @@ static const char *exceptionClass = "java/lang/Exception"; static void JniInitialize(JNIEnv *env) { + if (UNLIKELY(env ==nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } std::lock_guard lk(gInitLock); if (UNLIKELY(gLongClass == nullptr)) { gLongClass = env->FindClass("java/lang/Long"); if (UNLIKELY(gLongClass == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to find class java/lang/Long"); + return; } gLongValueFieldId = env->GetFieldID(gLongClass, "value", "J"); @@ -38,24 +44,53 @@ static void JniInitialize(JNIEnv *env) JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *env, jobject, jintArray jTypeIds) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } + if (UNLIKELY(jTypeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), 
"jTypeIds is null."); + return 0; + } std::shared_ptr instance = std::make_shared(); if (UNLIKELY(instance == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to create instance for ock merge reader"); return 0; } - bool result = instance->Initialize(env->GetIntArrayElements(jTypeIds, nullptr), env->GetArrayLength(jTypeIds)); + auto typeIds = env->GetIntArrayElements(jTypeIds, nullptr); + if (UNLIKELY(typeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), "Failed to get int array elements."); + return 0; + } + bool result = instance->Initialize(typeIds, env->GetArrayLength(jTypeIds)); if (UNLIKELY(!result)) { + env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); env->ThrowNew(env->FindClass(exceptionClass), "Failed to initialize ock merge reader"); return 0; } - + env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); return gBlobReader.Insert(instance); } +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *env, jobject, jlong jReaderId) +{ + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIENV is null."); + return; + } + + gBlobReader.Erase(jReaderId); +} + JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVectorBatch(JNIEnv *env, jobject, jlong jReaderId, jlong jAddress, jint jRemain, jint jMaxRow, jint jMaxSize, jobject jRowCnt) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return -1; + } + auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -80,6 +115,10 @@ JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVecValueLength(JNIEnv *env, jobject, jlong jReaderId, jint jColIndex) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -100,7 +139,12 @@ JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeCopyVecDataInVB(JNIEnv *env, jobject, jlong jReaderId, jlong dstNativeVec, jint jColIndex) { - auto dstVector = reinterpret_cast(dstNativeVec); // get from scala which is real vector + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } + + auto dstVector = reinterpret_cast(dstNativeVec); // get from scala which is real vector if (UNLIKELY(dstVector == nullptr)) { std::string errMsg = "Invalid dst vector address for reader id " + std::to_string(jReaderId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h similarity index 86% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h index 80a63c403ef8ce43ee5be522ab6bfd5fea6c9b37..eb8a692a7dcde68fafed60820da44870c3fc3a3e 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h @@ -18,6 +18,12 @@ extern "C" { */ JNIEXPORT jlong JNICALL 
Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *, jobject, jintArray); +/* + * Class: com_huawei_ock_spark_jni_OckShuffleJniReader + * Method: close + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *, jobject, jlong); /* * Class: com_huawei_ock_spark_jni_OckShuffleJniReader * Method: nativeGetVectorBatch diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp similarity index 86% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp index 61633605eb8afbf26abeeea595fcfc48742f3498..e1bcdec442798804d80ba6bb51ca88f0ce74cc19 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp @@ -20,11 +20,15 @@ static const char *exceptionClass = "java/lang/Exception"; JNIEXPORT jboolean JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_initialize(JNIEnv *env, jobject) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return JNI_FALSE; + } gSplitResultClass = CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); gSplitResultConstructor = GetMethodID(env, gSplitResultClass, "", "(JJJJJ[J)V"); if (UNLIKELY(!OckShuffleSdk::Initialize())) { - std::cout << "Failed to load ock shuffle library." << std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to load ock shuffle library.").c_str()); return JNI_FALSE; } @@ -36,9 +40,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jstring jPartitioningMethod, jint jPartitionNum, jstring jColTypes, jint jColNum, jint jRegionSize, jint jMinCapacity, jint jMaxCapacity, jboolean jIsCompress) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto appIdStr = env->GetStringUTFChars(jAppId, JNI_FALSE); if (UNLIKELY(appIdStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("ApplicationId can't be empty").c_str()); + return 0; } auto appId = std::string(appIdStr); env->ReleaseStringUTFChars(jAppId, appIdStr); @@ -46,6 +55,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto partitioningMethodStr = env->GetStringUTFChars(jPartitioningMethod, JNI_FALSE); if (UNLIKELY(partitioningMethodStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Partitioning method can't be empty").c_str()); + return 0; } auto partitionMethod = std::string(partitioningMethodStr); env->ReleaseStringUTFChars(jPartitioningMethod, partitioningMethodStr); @@ -53,6 +63,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto colTypesStr = env->GetStringUTFChars(jColTypes, JNI_FALSE); if (UNLIKELY(colTypesStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Columns types can't be empty").c_str()); + return 0; } DataTypes colTypes = Deserialize(colTypesStr); @@ -63,7 +74,8 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jmethodID jMethodId = env->GetStaticMethodID(jThreadCls, "currentThread", "()Ljava/lang/Thread;"); jobject jThread = env->CallStaticObjectMethod(jThreadCls, jMethodId); if (UNLIKELY(jThread == 
nullptr)) { - std::cout << "Failed to get current thread instance." << std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to get current thread instance.").c_str()); + return 0; } else { jThreadId = env->CallLongMethod(jThread, env->GetMethodID(jThreadCls, "getId", "()J")); } @@ -71,16 +83,19 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto splitter = OckSplitter::Make(partitionMethod, jPartitionNum, colTypes.GetIds(), jColNum, (uint64_t)jThreadId); if (UNLIKELY(splitter == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to make ock splitter").c_str()); + return 0; } bool ret = splitter->SetShuffleInfo(appId, jShuffleId, jStageId, jStageAttemptNum, jMapId, jTaskAttemptId); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to set shuffle information").c_str()); + return 0; } ret = splitter->InitLocalBuffer(jRegionSize, jMinCapacity, jMaxCapacity, (jIsCompress == JNI_TRUE)); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to initialize local buffer").c_str()); + return 0; } return gOckSplitterMap.Insert(std::shared_ptr(splitter)); @@ -89,21 +104,28 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(JNIEnv *env, jobject, jlong splitterId, jlong nativeVectorBatch) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } auto vecBatch = (VectorBatch *)nativeVectorBatch; if (UNLIKELY(vecBatch == nullptr)) { std::string errMsg = "Invalid address for native vector batch."; env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } if (UNLIKELY(!splitter->Split(*vecBatch))) { std::string errMsg = "Failed to split vector batch by splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } delete vecBatch; @@ -112,10 +134,15 @@ JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(J JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return nullptr; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string error_message = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), error_message.c_str()); + return nullptr; } splitter->Stop(); // free resource @@ -132,10 +159,15 @@ JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_close(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } gOckSplitterMap.Erase(splitterId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.h 
b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/concurrent_map.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/concurrent_map.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/concurrent_map.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/concurrent_map.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/jni_common.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/jni_common.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/jni_common.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/proto/vec_data.proto similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/proto/vec_data.proto rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/proto/vec_data.proto diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/sdk/ock_shuffle_sdk.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/sdk/ock_shuffle_sdk.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/sdk/ock_shuffle_sdk.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/sdk/ock_shuffle_sdk.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp similarity index 81% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp index b9c6ced10a6742812d257ee6cf95c84b9e5b3ad0..d0fe8198b4eb15f8796e2e70ce4480761180cf59 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp @@ -23,9 +23,21 @@ bool OckHashWriteBuffer::Initialize(uint32_t regionSize, uint32_t minCapacity, u mIsCompress = isCompress; uint32_t bufferNeed = regionSize * mPartitionNum; mDataCapacity = std::min(std::max(bufferNeed, minCapacity), maxCapacity); + if (UNLIKELY(mDataCapacity < mSinglePartitionAndRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSinglePartitionAndRegionUsedSize * mPartitionNum"); + return false; + } mRegionPtRecordOffset = mDataCapacity - mSinglePartitionAndRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity < mSingleRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSingleRegionUsedSize * mPartitionNum"); + return false; + } mRegionUsedRecordOffset = mDataCapacity - mSingleRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity / mPartitionNum < mSinglePartitionAndRegionUsedSize)) { + LogError("mDataCapacity / mPartitionNum should be bigger than mSinglePartitionAndRegionUsedSize"); + return false; + } mEachPartitionSize = mDataCapacity / 
mPartitionNum - mSinglePartitionAndRegionUsedSize; mDoublePartitionSize = reserveSize * mEachPartitionSize; @@ -76,6 +88,10 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t return ResultFlag::UNEXPECTED; } + if (UNLIKELY(mTotalSize > UINT32_MAX - length)) { + LogError("mTotalSize + length exceed UINT32_MAX"); + return ResultFlag::UNEXPECTED; + } // 1. get the new region id for partitionId uint32_t regionId = UINT32_MAX; if (newRegion && !GetNewRegion(partitionId, regionId)) { @@ -98,7 +114,7 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t (mDoublePartitionSize - mRegionUsedSize[regionId] - mRegionUsedSize[nearRegionId]); if (remainBufLength >= length) { mRegionUsedSize[regionId] += length; - mTotalSize += length; // todo check + mTotalSize += length; return ResultFlag::ENOUGH; } @@ -111,8 +127,16 @@ uint8_t *OckHashWriteBuffer::GetEndAddressOfRegion(uint32_t partitionId, uint32_ regionId = mPtCurrentRegionId[partitionId]; if ((regionId % groupSize) == 0) { + if (UNLIKELY(regionId * mEachPartitionSize + mRegionUsedSize[regionId] < length)) { + LogError("regionId * mEachPartitionSize + mRegionUsedSize[regionId] should be bigger than length"); + return nullptr; + } offset = regionId * mEachPartitionSize + mRegionUsedSize[regionId] - length; } else { + if (UNLIKELY((regionId + 1) * mEachPartitionSize < mRegionUsedSize[regionId])) { + LogError("(regionId + 1) * mEachPartitionSize should be bigger than mRegionUsedSize[regionId]"); + return nullptr; + } offset = (regionId + 1) * mEachPartitionSize - mRegionUsedSize[regionId]; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp similarity index 52% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp index 80ff1737977846dee4dad93049c35ffb44509f13..d1ef824c4a3032e3305ac5d7b16cc7838f5f8684 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp @@ -8,19 +8,23 @@ #include "common/common.h" -using namespace omniruntime::type; using namespace omniruntime::vec; using namespace ock::dopspark; bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) { mColNum = colNum; - mVectorBatch = new (std::nothrow) VBDataDesc(colNum); + mVectorBatch = std::make_shared(); if (UNLIKELY(mVectorBatch == nullptr)) { LOG_ERROR("Failed to new instance for vector batch description"); return false; } + if (UNLIKELY(!mVectorBatch->Initialize(colNum))) { + LOG_ERROR("Failed to initialize vector batch."); + return false; + } + mColTypeIds.reserve(colNum); for (uint32_t index = 0; index < colNum; ++index) { mColTypeIds.emplace_back(typeIds[index]); @@ -29,44 +33,48 @@ bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) return true; } -bool 
OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) +bool OckMergeReader::GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) { uint8_t *address = startAddress; - vector.SetValueNulls(static_cast(address)); - vector.SetSize(rowNum); + vector->SetValueNulls(static_cast(address)); + vector->SetSize(rowNum); address += rowNum; switch (typeId) { case OMNI_BOOLEAN: { - vector.SetCapacityInBytes(sizeof(uint8_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint8_t) * rowNum); break; } case OMNI_SHORT: { - vector.SetCapacityInBytes(sizeof(uint16_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint16_t) * rowNum); break; } case OMNI_INT: case OMNI_DATE32: { - vector.SetCapacityInBytes(sizeof(uint32_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint32_t) * rowNum); break; } case OMNI_LONG: case OMNI_DOUBLE: case OMNI_DECIMAL64: case OMNI_DATE64: { - vector.SetCapacityInBytes(sizeof(uint64_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint64_t) * rowNum); break; } case OMNI_DECIMAL128: { - vector.SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte + vector->SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte break; } case OMNI_CHAR: case OMNI_VARCHAR: { // unknown length for value vector, calculate later // will add offset_vector_len when the length of values_vector is variable - vector.SetValueOffsets(static_cast(address)); + vector->SetValueOffsets(static_cast(address)); address += capacityOffset * (rowNum + 1); // 4 means value cost 4Byte - vector.SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + vector->SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + if (UNLIKELY(vector->GetCapacityInBytes() > maxCapacityInBytes)) { + LOG_ERROR("vector capacityInBytes exceed maxCapacityInBytes"); + return false; + } break; } default: { @@ -75,26 +83,26 @@ bool OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t } } - vector.SetValues(static_cast(address)); - address += vector.GetCapacityInBytes(); + vector->SetValues(static_cast(address)); + address += vector->GetCapacityInBytes(); startAddress = address; return true; } bool OckMergeReader::CalVectorValueLength(uint32_t colIndex, uint32_t &length) { - OckVector *vector = mVectorBatch->mColumnsHead[colIndex]; + auto vector = mVectorBatch->GetColumnHead(colIndex); + length = 0; for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(vector == nullptr)) { LOG_ERROR("Failed to calculate value length for column index %d", colIndex); return false; } - - mVectorBatch->mVectorValueLength[colIndex] += vector->GetCapacityInBytes(); + length += vector->GetCapacityInBytes(); vector = vector->GetNextVector(); } - length = mVectorBatch->mVectorValueLength[colIndex]; + mVectorBatch->SetColumnCapacity(colIndex, length); return true; } @@ -102,37 +110,27 @@ bool OckMergeReader::ScanOneVectorBatch(uint8_t *&startAddress) { uint8_t *address = startAddress; // get vector batch msg as vb_data_batch memory layout (upper) - mCurVBHeader = reinterpret_cast(address); - mVectorBatch->mHeader.rowNum += mCurVBHeader->rowNum; - mVectorBatch->mHeader.length += mCurVBHeader->length; + auto curVBHeader = reinterpret_cast(address); + mVectorBatch->AddTotalCapacity(curVBHeader->length); + mVectorBatch->AddTotalRowNum(curVBHeader->rowNum); address += sizeof(struct VBDataHeaderDesc); OckVector *curVector = nullptr; for (uint32_t colIndex = 0; colIndex < mColNum; colIndex++) { - 
curVector = mVectorBatch->mColumnsCur[colIndex]; - if (UNLIKELY(!GenerateVector(*curVector, mCurVBHeader->rowNum, mColTypeIds[colIndex], address))) { - LOG_ERROR("Failed to generate vector"); + auto curVector = mVectorBatch->GetCurColumn(colIndex); + if (UNLIKELY(curVector == nullptr)) { + LOG_ERROR("curVector is null, index %d", colIndex); return false; } - - if (curVector->GetNextVector() == nullptr) { - curVector = new (std::nothrow) OckVector(); - if (UNLIKELY(curVector == nullptr)) { - LOG_ERROR("Failed to new instance for ock vector"); - return false; - } - - // set next vector in the column merge list, and current column vector point to it - mVectorBatch->mColumnsCur[colIndex]->SetNextVector(curVector); - mVectorBatch->mColumnsCur[colIndex] = curVector; - } else { - mVectorBatch->mColumnsCur[colIndex] = curVector->GetNextVector(); + if (UNLIKELY(!GenerateVector(curVector, curVBHeader->rowNum, mColTypeIds[colIndex], address))) { + LOG_ERROR("Failed to generate vector"); + return false; } } - if (UNLIKELY((uint32_t)(address - startAddress) != mCurVBHeader->length)) { + if (UNLIKELY((uint32_t)(address - startAddress) != curVBHeader->length)) { LOG_ERROR("Failed to scan one vector batch as invalid date setting %d vs %d", - (uint32_t)(address - startAddress), mCurVBHeader->length); + (uint32_t)(address - startAddress), curVBHeader->length); return false; } @@ -159,49 +157,72 @@ bool OckMergeReader::GetMergeVectorBatch(uint8_t *&startAddress, uint32_t remain } mMergeCnt++; - if (mVectorBatch->mHeader.rowNum >= maxRowNum || mVectorBatch->mHeader.length >= maxSize) { + if (mVectorBatch->GetTotalRowNum() >= maxRowNum || mVectorBatch->GetTotalCapacity() >= maxSize) { break; } } startAddress = address; - return true; } -bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, - OckVector &srcVector, uint32_t colIndex) +bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, + uint32_t &remainingCapacity, OckVectorPtr &srcVector) { - errno_t ret = memcpy_s(nulls, srcVector.GetSize(), srcVector.GetValueNulls(), srcVector.GetSize()); + uint32_t srcSize = srcVector->GetSize(); + if (UNLIKELY(remainingSize < srcSize)) { + LOG_ERROR("Not enough resource. remainingSize %d, srcSize %d.", remainingSize, srcSize); + return false; + } + errno_t ret = memcpy_s(nulls, remainingSize, srcVector->GetValueNulls(), srcSize); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy null vector"); return false; } - nulls += srcVector.GetSize(); + nulls += srcSize; + remainingSize -= srcSize; - if (srcVector.GetCapacityInBytes() > 0) { - ret = memcpy_s(values, srcVector.GetCapacityInBytes(), srcVector.GetValues(), - srcVector.GetCapacityInBytes()); + uint32_t srcCapacity = srcVector->GetCapacityInBytes(); + if (UNLIKELY(remainingCapacity < srcCapacity)) { + LOG_ERROR("Not enough resource. 
remainingCapacity %d, srcCapacity %d", remainingCapacity, srcCapacity); + return false; + } + if (srcCapacity > 0) { + ret = memcpy_s(values, remainingCapacity, srcVector->GetValues(), srcCapacity); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy values vector"); return false; } - values += srcVector.GetCapacityInBytes(); + values += srcCapacity; + remainingCapacity -=srcCapacity; } return true; } -bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) +bool OckMergeReader::CopyDataToVector(BaseVector *dstVector, uint32_t colIndex) { // point to first src vector in list - OckVector *srcVector = mVectorBatch->mColumnsHead[colIndex]; + auto srcVector = mVectorBatch->GetColumnHead(colIndex); - auto *nullsAddress = (uint8_t *)dstVector->GetValueNulls(); - auto *valuesAddress = (uint8_t *)dstVector->GetValues(); - uint32_t *offsetsAddress = (uint32_t *)dstVector->GetValueOffsets(); + auto *nullsAddress = (uint8_t *)omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(dstVector); + auto *valuesAddress = (uint8_t *)VectorHelper::UnsafeGetValues(dstVector); + uint32_t *offsetsAddress = (uint32_t *)VectorHelper::UnsafeGetOffsetsAddr(dstVector); + dstVector->SetNullFlag(true); uint32_t totalSize = 0; uint32_t currentSize = 0; + if (dstVector->GetSize() < 0) { + LOG_ERROR("Invalid vector size %d", dstVector->GetSize()); + return false; + } + uint32_t remainingSize = (uint32_t)dstVector->GetSize(); + uint32_t remainingCapacity = 0; + if (mColTypeIds[colIndex] == OMNI_CHAR || mColTypeIds[colIndex] == OMNI_VARCHAR) { + auto *varCharVector = reinterpret_cast> *>(dstVector); + remainingCapacity = omniruntime::vec::unsafe::UnsafeStringVector::GetContainer(varCharVector)->GetCapacityInBytes(); + } else { + remainingCapacity = GetDataSize(colIndex) * remainingSize; + } for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(srcVector == nullptr)) { @@ -209,7 +230,7 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) return false; } - if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, *srcVector, colIndex))) { + if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, remainingSize, remainingCapacity, srcVector))) { return false; } @@ -226,9 +247,9 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) if (mColTypeIds[colIndex] == OMNI_CHAR || mColTypeIds[colIndex] == OMNI_VARCHAR) { *offsetsAddress = totalSize; - if (UNLIKELY(totalSize != mVectorBatch->mVectorValueLength[colIndex])) { + if (UNLIKELY(totalSize != mVectorBatch->GetColumnCapacity(colIndex))) { LOG_ERROR("Failed to calculate variable vector value length, %d to %d", totalSize, - mVectorBatch->mVectorValueLength[colIndex]); + mVectorBatch->GetColumnCapacity(colIndex)); return false; } } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h similarity index 47% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h index b5d5fba4d7ddd910146126201cc27776f6ad813b..838dd6a8d6e78b3557764869f1240c47b48aa398 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h @@ -10,38 +10,69 @@ namespace ock { namespace dopspark { +using namespace omniruntime::type; class 
OckMergeReader { public: bool Initialize(const int32_t *typeIds, uint32_t colNum); bool GetMergeVectorBatch(uint8_t *&address, uint32_t remain, uint32_t maxRowNum, uint32_t maxSize); - bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, OckVector &srcVector, uint32_t colIndex); - bool CopyDataToVector(omniruntime::vec::Vector *dstVector, uint32_t colIndex); + bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, uint32_t &remainingCapacity, + OckVectorPtr &srcVector); + bool CopyDataToVector(omniruntime::vec::BaseVector *dstVector, uint32_t colIndex); [[nodiscard]] inline uint32_t GetVectorBatchLength() const { - return mVectorBatch->mHeader.length; + return mVectorBatch->GetTotalCapacity(); } [[nodiscard]] inline uint32_t GetRowNumAfterMerge() const { - return mVectorBatch->mHeader.rowNum; + return mVectorBatch->GetTotalRowNum(); } bool CalVectorValueLength(uint32_t colIndex, uint32_t &length); + inline uint32_t GetDataSize(int32_t colIndex) + { + switch (mColTypeIds[colIndex]) { + case OMNI_BOOLEAN: { + return sizeof(uint8_t); + } + case OMNI_SHORT: { + return sizeof(uint16_t); + } + case OMNI_INT: + case OMNI_DATE32: { + return sizeof(uint32_t); + } + case OMNI_LONG: + case OMNI_DOUBLE: + case OMNI_DECIMAL64: + case OMNI_DATE64: { + return sizeof(uint64_t); + } + case OMNI_DECIMAL128: { + return decimal128Size; + } + default: { + LOG_ERROR("Unsupported data type id %d", mColTypeIds[colIndex]); + return false; + } + } + } + private: - static bool GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); + static bool GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); bool ScanOneVectorBatch(uint8_t *&startAddress); static constexpr int capacityOffset = 4; static constexpr int decimal128Size = 16; + static constexpr int maxCapacityInBytes = 1073741824; private: // point to shuffle blob current vector batch data header uint32_t mColNum = 0; uint32_t mMergeCnt = 0; std::vector mColTypeIds {}; - VBHeaderPtr mCurVBHeader = nullptr; VBDataDescPtr mVectorBatch = nullptr; }; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp similarity index 65% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp index 5c046686755c88ccf3e0bdb39e70633c49015aca..fe83d01786666df18422ea31efd6c91638fb8e52 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp @@ -23,39 +23,49 @@ bool OckSplitter::ToSplitterTypeId(const int32_t *vBColTypes) for (uint32_t colIndex = 0; colIndex < mColNum; ++colIndex) { switch (vBColTypes[colIndex]) { case OMNI_BOOLEAN: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_1BYTE); - mMinDataLenInVBByRow += uint8Size; + CastOmniToShuffleType(OMNI_BOOLEAN, ShuffleTypeId::SHUFFLE_1BYTE, uint8Size); break; } case OMNI_SHORT: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_2BYTE); - mMinDataLenInVBByRow += uint16Size; + CastOmniToShuffleType(OMNI_SHORT, ShuffleTypeId::SHUFFLE_2BYTE, uint16Size); + break; + } + case OMNI_DATE32: { + CastOmniToShuffleType(OMNI_DATE32, ShuffleTypeId::SHUFFLE_4BYTE, uint32Size); break; } - case OMNI_DATE32: case OMNI_INT: { - 
mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_4BYTE); - mMinDataLenInVBByRow += uint32Size; // 4 means value cost 4Byte + CastOmniToShuffleType(OMNI_INT, ShuffleTypeId::SHUFFLE_4BYTE, uint32Size); + break; + } + case OMNI_DATE64: { + CastOmniToShuffleType(OMNI_DATE64, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_DOUBLE: { + CastOmniToShuffleType(OMNI_DOUBLE, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_DECIMAL64: { + CastOmniToShuffleType(OMNI_DECIMAL64, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); break; } - case OMNI_DATE64: - case OMNI_DOUBLE: - case OMNI_DECIMAL64: case OMNI_LONG: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_8BYTE); - mMinDataLenInVBByRow += uint64Size; // 8 means value cost 8Byte + CastOmniToShuffleType(OMNI_LONG, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_CHAR: { + CastOmniToShuffleType(OMNI_CHAR, ShuffleTypeId::SHUFFLE_BINARY, uint32Size); + mColIndexOfVarVec.emplace_back(colIndex); break; } - case OMNI_CHAR: - case OMNI_VARCHAR: { // unknown length for value vector, calculate later - mMinDataLenInVBByRow += uint32Size; // 4 means offset - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_BINARY); + case OMNI_VARCHAR: { // unknown length for value vector, calculate later + CastOmniToShuffleType(OMNI_VARCHAR, ShuffleTypeId::SHUFFLE_BINARY, uint32Size); mColIndexOfVarVec.emplace_back(colIndex); break; } case OMNI_DECIMAL128: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_DECIMAL128); - mMinDataLenInVBByRow += decimal128Size; // 16 means value cost 8Byte + CastOmniToShuffleType(OMNI_DECIMAL128, ShuffleTypeId::SHUFFLE_DECIMAL128, decimal128Size); break; } default: { @@ -70,11 +80,15 @@ bool OckSplitter::ToSplitterTypeId(const int32_t *vBColTypes) return true; } -void OckSplitter::InitCacheRegion() +bool OckSplitter::InitCacheRegion() { mCacheRegion.reserve(mPartitionNum); mCacheRegion.resize(mPartitionNum); + if (UNLIKELY(mOckBuffer->GetRegionSize() * 2 < mMinDataLenInVB || mMinDataLenInVBByRow == 0)) { + LOG_DEBUG("regionSize * doubleNum should be bigger than mMinDataLenInVB %d", mMinDataLenInVBByRow); + return false; + } uint32_t rowNum = (mOckBuffer->GetRegionSize() * 2 - mMinDataLenInVB) / mMinDataLenInVBByRow; LOG_INFO("Each region can cache row number is %d", rowNum); @@ -84,6 +98,7 @@ void OckSplitter::InitCacheRegion() region.mLength = 0; region.mRowNum = 0; } + return true; } bool OckSplitter::Initialize(const int32_t *colTypeIds) @@ -122,6 +137,10 @@ std::shared_ptr OckSplitter::Create(const int32_t *colTypeIds, int3 std::shared_ptr OckSplitter::Make(const std::string &partitionMethod, int partitionNum, const int32_t *colTypeIds, int32_t colNum, uint64_t threadId) { + if (UNLIKELY(colTypeIds == nullptr || colNum == 0)) { + LOG_ERROR("colTypeIds is null or colNum is 0, colNum %d", colNum); + return nullptr; + } if (partitionMethod == "hash" || partitionMethod == "rr" || partitionMethod == "range") { return Create(colTypeIds, colNum, partitionNum, false, threadId); } else if (UNLIKELY(partitionMethod == "single")) { @@ -132,35 +151,38 @@ std::shared_ptr OckSplitter::Make(const std::string &partitionMetho } } -uint32_t OckSplitter::GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex, uint8_t **address) const +uint32_t OckSplitter::GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex) const { - auto vector = mIsSinglePt ? 
vb.GetVector(colIndex) : vb.GetVector(static_cast(colIndex + 1)); - if (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { - return reinterpret_cast(vector)->GetVarchar(rowIndex, address); + auto vector = mIsSinglePt ? vb.Get(colIndex) : vb.Get(static_cast(colIndex + 1)); + if (vector->GetEncoding() == OMNI_DICTIONARY) { + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + return static_cast(value.length()); } else { - return reinterpret_cast(vector)->GetValue(rowIndex, address); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + return static_cast(value.length()); } } uint32_t OckSplitter::GetRowLengthInBytes(VectorBatch &vb, uint32_t rowIndex) const { - uint8_t *address = nullptr; uint32_t length = mMinDataLenInVBByRow; // calculate variable width value for (auto &colIndex : mColIndexOfVarVec) { - length += GetVarVecValue(vb, rowIndex, colIndex, &address); + length += GetVarVecValue(vb, rowIndex, colIndex); } return length; } -bool OckSplitter::WriteNullValues(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteNullValues(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) { uint8_t *nullAddress = address; for (uint32_t index = 0; index < rowNum; ++index) { - *nullAddress = const_cast((uint8_t *)(VectorHelper::GetNullsAddr(vector)))[rowIndexes[index]]; + *nullAddress = const_cast((uint8_t *)(unsafe::UnsafeBaseVector::GetNulls(vector)))[rowIndexes[index]]; nullAddress++; } @@ -169,34 +191,45 @@ bool OckSplitter::WriteNullValues(Vector *vector, std::vector &rowInde } template -bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::vector &rowIndexes, +bool OckSplitter::WriteFixedWidthValueTemple(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, T *&address) { T *dstValues = address; T *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); - if (UNLIKELY(ids == nullptr)) { - LOG_ERROR("Failed to allocate space for fixed width value ids."); + int32_t idsNum = mCurrentVB->GetRowCount(); + int64_t idsSizeInBytes = idsNum * sizeof(int32_t); + auto ids = VectorHelper::UnsafeGetValues(vector); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); return false; } - auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); - if (UNLIKELY(dictionary == nullptr)) { - LOG_ERROR("Failed to get dictionary"); - return false; - } - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]]]; // write value to local blob + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum.", idIndex, idsNum); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + *dstValues++ = srcValues[rowIndex]; // write value to local blob } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); } else { - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetValues(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + 
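The dictionary branch of WriteFixedWidthValueTemple above boils down to a bounds-checked gather: every requested row is mapped through the ids array into the dictionary's value buffer before one fixed-width value is copied out. A minimal stand-alone sketch of that pattern, using simplified types rather than the omniruntime vector classes:

#include <cstdint>
#include <vector>

// dst[i] = dict[ids[rows[i]]] for every selected row, failing fast on an
// out-of-range row index (the case the new LOG_ERROR path reports).
template <typename T>
bool GatherThroughDictionary(const T *dict, const int32_t *ids, int32_t idsNum,
                             const std::vector<uint32_t> &rows, T *dst)
{
    for (uint32_t row : rows) {
        if (row >= static_cast<uint32_t>(idsNum)) {
            return false;            // invalid selection index
        }
        *dst++ = dict[ids[row]];     // decode one fixed-width value
    }
    return true;
}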
int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[rowIndexes[index]]; // write value to local blob + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } + *dstValues++ = srcValues[rowIndex]; // write value to local blob } } @@ -205,37 +238,45 @@ bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::v return true; } -bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, - uint32_t rowNum, uint64_t *&address) +bool OckSplitter::WriteDecimal128(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, + uint64_t *&address) { uint64_t *dstValues = address; uint64_t *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); - if (UNLIKELY(ids == nullptr)) { - LOG_ERROR("Failed to allocate space for fixed width value ids."); - return false; - } - - auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); - if (UNLIKELY(dictionary == nullptr)) { - LOG_ERROR("Failed to get dictionary"); + uint32_t idsNum = mCurrentVB->GetRowCount(); + auto ids = VectorHelper::UnsafeGetValues(vector); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); return false; } - - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]] << 1]; - *dstValues++ = srcValues[(reinterpret_cast(ids)[rowIndexes[index]] << 1) | 1]; + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum.", idIndex, idsNum); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + *dstValues++ = srcValues[rowIndex << 1]; + *dstValues++ = srcValues[rowIndex << 1 | 1]; } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); } else { - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetValues(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } *dstValues++ = srcValues[rowIndexes[index] << 1]; // write value to local blob - *dstValues++ = srcValues[(rowIndexes[index] << 1) | 1]; // write value to local blob + *dstValues++ = srcValues[rowIndexes[index] << 1 | 1]; // write value to local blob } } @@ -243,10 +284,10 @@ bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteFixedWidthValue(BaseVector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, + uint32_t rowNum, uint8_t *&address) { - bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); + bool isDict = (vector->GetEncoding() == OMNI_DICTIONARY); switch (typeId) { case ShuffleTypeId::SHUFFLE_1BYTE: { WriteFixedWidthValueTemple(vector, isDict, rowIndexes, 
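WriteDecimal128 relies on Decimal128 values being laid out as two consecutive 64-bit words per row, which is why each copy reads srcValues[rowIndex << 1] and srcValues[(rowIndex << 1) | 1]. The snippet below simply restates that layout in isolation:

#include <cstdint>

// Row r of a Decimal128 column occupies words 2r and 2r + 1 of the value buffer;
// both words are appended to the destination blob for each selected row.
inline void CopyDecimal128Row(const uint64_t *srcValues, uint32_t rowIndex, uint64_t *&dst)
{
    *dst++ = srcValues[rowIndex << 1];        // first 64-bit word
    *dst++ = srcValues[(rowIndex << 1) | 1];  // second 64-bit word
}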
rowNum, address); @@ -285,21 +326,33 @@ bool OckSplitter::WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, return true; } -bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, - uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteVariableWidthValue(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, + uint8_t *&address) { - bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); + bool isDict = (vector->GetEncoding() == OMNI_DICTIONARY); auto *offsetAddress = reinterpret_cast(address); // point the offset space base address uint8_t *valueStartAddress = address + (rowNum + 1) * sizeof(int32_t); // skip the offsets space uint8_t *valueAddress = valueStartAddress; - int32_t length = 0; + uint32_t length = 0; uint8_t *srcValues = nullptr; + int32_t vectorSize = vector->GetSize(); for (uint32_t rowCnt = 0; rowCnt < rowNum; rowCnt++) { + uint32_t rowIndex = rowIndexes[rowCnt]; + if (UNLIKELY(rowIndex >= vectorSize)) { + LOG_ERROR("Invalid rowIndex %d, vectorSize %d.", rowIndex, vectorSize); + return false; + } if (isDict) { - length = reinterpret_cast(vector)->GetVarchar(rowIndexes[rowCnt], &srcValues); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + srcValues = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); } else { - length = reinterpret_cast(vector)->GetValue(rowIndexes[rowCnt], &srcValues); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + srcValues = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); } // write the null value in the vector with row index to local blob if (UNLIKELY(length > 0 && memcpy_s(valueAddress, length, srcValues, length) != EOK)) { @@ -320,7 +373,7 @@ bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector bool OckSplitter::WriteOneVector(VectorBatch &vb, uint32_t colIndex, std::vector &rowIndexes, uint32_t rowNum, uint8_t **address) { - Vector *vector = vb.GetVector(colIndex); + BaseVector *vector = vb.Get(colIndex); if (UNLIKELY(vector == nullptr)) { LOG_ERROR("Failed to get vector with index %d in current vector batch", colIndex); return false; @@ -353,6 +406,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) uint32_t regionId = 0; // backspace from local blob the region end address to remove preoccupied bytes for the vector batch region auto address = mOckBuffer->GetEndAddressOfRegion(partitionId, regionId, vbRegion->mLength); + if (UNLIKELY(address == nullptr)) { + LOG_ERROR("Failed to get address with partitionId %d", partitionId); + return false; + } // write the header information of the vector batch in local blob auto header = reinterpret_cast(address); header->length = vbRegion->mLength; @@ -361,6 +418,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) if (!mOckBuffer->IsCompress()) { // record write bytes when don't need compress mTotalWriteBytes += header->length; } + if (UNLIKELY(partitionId > mPartitionLengths.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } mPartitionLengths[partitionId] += header->length; // we can't get real length when compress address += vbHeaderSize; // 8 means header length so skip @@ -382,6 +443,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) { + if (UNLIKELY(mPartitionNum > 
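WriteVariableWidthValue writes an offset table of rowNum + 1 int32 entries followed by the concatenated value bytes, which is why the value area starts at address + (rowNum + 1) * sizeof(int32_t). Below is a self-contained sketch of one plausible encoding of that layout, with simplified types and a hypothetical helper name:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Packs values as [rowNum + 1 offsets][value bytes...]; returns the number of
// value bytes written so the caller can advance past the whole column.
static uint32_t PackVariableWidthColumn(const std::vector<std::string> &values, uint8_t *out)
{
    auto *offsets = reinterpret_cast<int32_t *>(out);
    uint8_t *bytes = out + (values.size() + 1) * sizeof(int32_t);  // skip offset table
    uint32_t total = 0;
    offsets[0] = 0;
    for (size_t i = 0; i < values.size(); ++i) {
        std::memcpy(bytes + total, values[i].data(), values[i].size());
        total += static_cast<uint32_t>(values[i].size());
        offsets[i + 1] = static_cast<int32_t>(total);  // cumulative end offset per row
    }
    return total;
}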
mCacheRegion.size())) { + LOG_ERROR("Illegal mPartitionNum %d", mPartitionNum); + return false; + } for (uint32_t partitionId = 0; partitionId < mPartitionNum; ++partitionId) { if (mCacheRegion[partitionId].mRowNum == 0) { continue; @@ -421,6 +486,10 @@ bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) bool OckSplitter::PreoccupiedBufferSpace(VectorBatch &vb, uint32_t partitionId, uint32_t rowIndex, uint32_t rowLength, bool newRegion) { + if (UNLIKELY(partitionId > mCacheRegion.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } uint32_t preoccupiedSize = rowLength; if (mCacheRegion[partitionId].mRowNum == 0) { preoccupiedSize += mMinDataLenInVB; // means create a new vector batch, so will cost header @@ -472,7 +541,7 @@ bool OckSplitter::Split(VectorBatch &vb) ResetCacheRegion(); // clear the record about those partition regions in old vector batch mCurrentVB = &vb; // point to current native vector batch address // the first vector in vector batch that record partitionId about same index row when exist multiple partition - mPtViewInCurVB = mIsSinglePt ? nullptr : reinterpret_cast(vb.GetVector(0)); + mPtViewInCurVB = mIsSinglePt ? nullptr : reinterpret_cast *>(vb.Get(0)); // PROFILE_START_L1(PREOCCUPIED_STAGE) for (int rowIndex = 0; rowIndex < vb.GetRowCount(); ++rowIndex) { @@ -499,7 +568,7 @@ bool OckSplitter::Split(VectorBatch &vb) } // release data belong to the vector batch in memory after write it to local blob - vb.ReleaseAllVectors(); + vb.FreeAllVectors(); // PROFILE_END_L1(RELEASE_VECTOR) mCurrentVB = nullptr; diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h similarity index 86% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h index fc81195099f49a2cae202a10324f0725ee5a08bb..6118289b7cf521aeff9d862884d96375cd2e9412 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h @@ -20,8 +20,6 @@ #include "vec_data.pb.h" #include "ock_hash_write_buffer.h" -#include "memory/base_allocator.h" - using namespace spark; using namespace omniruntime::vec; using namespace omniruntime::type; @@ -70,7 +68,10 @@ public: return false; } - InitCacheRegion(); + if (UNLIKELY(!InitCacheRegion())) { + LOG_ERROR("Failed to initialize CacheRegion"); + return false; + } return true; } @@ -89,7 +90,7 @@ private: bool isSinglePt, uint64_t threadId); bool ToSplitterTypeId(const int32_t *vBColTypes); - uint32_t GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex, uint8_t **address) const; + uint32_t GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex) const; uint32_t GetRowLengthInBytes(VectorBatch &vb, uint32_t rowIndex) const; inline uint32_t GetPartitionIdOfRow(uint32_t rowIndex) @@ -98,7 +99,12 @@ private: return mIsSinglePt ? 
0 : mPtViewInCurVB->GetValue(rowIndex); } - void InitCacheRegion(); + void CastOmniToShuffleType(DataTypeId omniType, ShuffleTypeId shuffleType, uint32_t size) + { + mVBColShuffleTypes.emplace_back(shuffleType); + mMinDataLenInVBByRow += size; + } + bool InitCacheRegion(); inline void ResetCacheRegion() { @@ -137,21 +143,19 @@ private: bool newRegion); bool WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId); - static bool WriteNullValues(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); + static bool WriteNullValues(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); template - bool WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, + bool WriteFixedWidthValueTemple(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, T *&address); - bool WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint64_t *&address); - bool WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, + bool WriteDecimal128(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint64_t *&address); + bool WriteFixedWidthValue(BaseVector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); - static bool WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, + static bool WriteVariableWidthValue(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); bool WriteOneVector(VectorBatch &vb, uint32_t colIndex, std::vector &rowIndexes, uint32_t rowNum, uint8_t **address); private: - BaseAllocator *mAllocator = omniruntime::mem::GetProcessRootAllocator(); - static constexpr uint32_t vbDataHeadLen = 8; // Byte static constexpr uint32_t uint8Size = 1; static constexpr uint32_t uint16Size = 2; @@ -159,6 +163,7 @@ private: static constexpr uint32_t uint64Size = 8; static constexpr uint32_t decimal128Size = 16; static constexpr uint32_t vbHeaderSize = 8; + static constexpr uint32_t doubleNum = 2; /* the region use for all vector batch ---------------------------------------------------------------- */ // this splitter which corresponding to one map task in one shuffle, so some params is same uint32_t mPartitionNum = 0; @@ -187,7 +192,7 @@ private: std::vector mCacheRegion {}; // the vector point to vector0 in current vb which record rowIndex -> ptId - IntVector *mPtViewInCurVB = nullptr; + Vector *mPtViewInCurVB = nullptr; /* ock shuffle resource -------------------------------------------------------------------------------- */ OckHashWriteBuffer *mOckBuffer = nullptr; diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h similarity index 33% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h index e07e67f17d7281f5df0e1d4ee17a4949bc1da697..03e444b6ce4e7284a36e859c327cc51546fb26ab 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h @@ -6,7 +6,7 @@ #define SPARK_THESTRAL_PLUGIN_OCK_TYPE_H #include "ock_vector.h" -#include "common/debug.h" +#include "common/common.h" namespace ock { namespace dopspark { @@ -33,58 +33,118 @@ enum class ShuffleTypeId : int { using VBHeaderPtr = 
struct VBDataHeaderDesc { uint32_t length = 0; // 4Byte uint32_t rowNum = 0; // 4Byte -} __attribute__((packed)) * ; +} __attribute__((packed)) *; -using VBDataDescPtr = struct VBDataDesc { - explicit VBDataDesc(uint32_t colNum) +class VBDataDesc { +public: + VBDataDesc() = default; + ~VBDataDesc() { + for (auto &vector : mColumnsHead) { + if (vector == nullptr) { + continue; + } + auto currVector = vector; + while (currVector->GetNextVector() != nullptr) { + auto nextVector = currVector->GetNextVector(); + currVector->SetNextVector(nullptr); + currVector = nextVector; + } + } + } + + bool Initialize(uint32_t colNum) + { + this->colNum = colNum; mHeader.rowNum = 0; mHeader.length = 0; - mColumnsHead.reserve(colNum); mColumnsHead.resize(colNum); - mColumnsCur.reserve(colNum); mColumnsCur.resize(colNum); - mVectorValueLength.reserve(colNum); - mVectorValueLength.resize(colNum); + mColumnsCapacity.resize(colNum); - for (auto &index : mColumnsHead) { - index = new (std::nothrow) OckVector(); + for (auto &vector : mColumnsHead) { + vector = std::make_shared(); + if (vector == nullptr) { + mColumnsHead.clear(); + return false; + } } + return true; } inline void Reset() { mHeader.rowNum = 0; mHeader.length = 0; - std::fill(mVectorValueLength.begin(), mVectorValueLength.end(), 0); + std::fill(mColumnsCapacity.begin(), mColumnsCapacity.end(), 0); for (uint32_t index = 0; index < mColumnsCur.size(); ++index) { mColumnsCur[index] = mColumnsHead[index]; } } + std::shared_ptr GetColumnHead(uint32_t colIndex) { + if (colIndex >= colNum) { + return nullptr; + } + return mColumnsHead[colIndex]; + } + + void SetColumnCapacity(uint32_t colIndex, uint32_t length) { + mColumnsCapacity[colIndex] = length; + } + + uint32_t GetColumnCapacity(uint32_t colIndex) { + return mColumnsCapacity[colIndex]; + } + + std::shared_ptr GetCurColumn(uint32_t colIndex) + { + if (colIndex >= colNum) { + return nullptr; + } + auto currVector = mColumnsCur[colIndex]; + if (currVector->GetNextVector() == nullptr) { + auto newCurVector = std::make_shared(); + if (UNLIKELY(newCurVector == nullptr)) { + LOG_ERROR("Failed to new instance for ock vector"); + return nullptr; + } + currVector->SetNextVector(newCurVector); + mColumnsCur[colIndex] = newCurVector; + } else { + mColumnsCur[colIndex] = currVector->GetNextVector(); + } + return currVector; + } + + uint32_t GetTotalCapacity() + { + return mHeader.length; + } + + uint32_t GetTotalRowNum() + { + return mHeader.rowNum; + } + + void AddTotalCapacity(uint32_t length) { + mHeader.length += length; + } + + void AddTotalRowNum(uint32_t rowNum) + { + mHeader.rowNum +=rowNum; + } + +private: + uint32_t colNum = 0; VBDataHeaderDesc mHeader; - std::vector mVectorValueLength; - std::vector mColumnsCur; - std::vector mColumnsHead; // Array[List[OckVector *]] -} * ; + std::vector mColumnsCapacity; + std::vector mColumnsCur; + std::vector mColumnsHead; // Array[List[OckVector *]] +}; +using VBDataDescPtr = std::shared_ptr; } } -#define PROFILE_START_L1(name) \ - long tcDiff##name = 0; \ - struct timespec tcStart##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcStart##name); - -#define PROFILE_END_L1(name) \ - struct timespec tcEnd##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcEnd##name); \ - \ - long diffSec##name = tcEnd##name.tv_sec - tcStart##name.tv_sec; \ - if (diffSec##name == 0) { \ - tcDiff##name = tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } else { \ - tcDiff##name = diffSec##name * 1000000000 + tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } - -#define 
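GetCurColumn above implements a grow-on-demand chain per column: it hands back the node that is currently writable and, if no successor exists yet, links a fresh one before advancing. Reduced to its essentials (Node is a placeholder standing in for OckVector):

#include <memory>

struct Node {
    std::shared_ptr<Node> next;
    // payload omitted; OckVector carries the actual value/null/offset addresses
};

// Return the node the caller should fill, appending a new node when the chain
// has no successor, and leave `cur` pointing at the next node of the column.
std::shared_ptr<Node> AdvanceColumn(std::shared_ptr<Node> &cur)
{
    auto current = cur;
    if (current->next == nullptr) {
        current->next = std::make_shared<Node>();  // grow the chain on demand
    }
    cur = current->next;
    return current;
}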
PROFILE_VALUE(name) tcDiff##name #endif // SPARK_THESTRAL_PLUGIN_OCK_TYPE_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h similarity index 88% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h index 0cfca5d63173c04c37771900e1ac17c2c04e8bba..515f88db8355a58321a7290179e48b48802cb8cc 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h @@ -69,12 +69,12 @@ public: valueOffsetsAddress = address; } - inline void SetNextVector(OckVector *next) + inline void SetNextVector(std::shared_ptr next) { mNext = next; } - inline OckVector *GetNextVector() + inline std::shared_ptr GetNextVector() { return mNext; } @@ -87,8 +87,9 @@ private: void *valueNullsAddress = nullptr; void *valueOffsetsAddress = nullptr; - OckVector *mNext = nullptr; + std::shared_ptr mNext = nullptr; }; +using OckVectorPtr = std::shared_ptr; } } #endif // SPARK_THESTRAL_PLUGIN_OCK_VECTOR_H diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt similarity index 95% rename from omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt index 53605f08556f538682e83427a130c1684318702f..dedb097bb17e65b3a6e42a602be15423c99e9652 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt @@ -28,7 +28,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.2.0-aarch64 securec ock_columnar_shuffle) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/CMakeLists.txt similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/CMakeLists.txt diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp similarity index 91% rename from omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp index 7980cbf198d192488c313fd719f340fc71c0521a..cc02862fd1b91b1117bdfb07346af13d27db5259 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp @@ -54,7 +54,7 @@ bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) info << "vector_batch: { "; for (uint32_t colIndex = 0; colIndex < gColNum; colIndex++) { auto typeId = static_cast(gVecTypeIds[colIndex]); - Vector *vector = OckNewbuildVector(typeId, rowNum); + BaseVector *vector = OckNewbuildVector(typeId, rowNum); if (typeId == OMNI_VARCHAR) { uint32_t varlength = 0; 
instance->CalVectorValueLength(colIndex, varlength); @@ -75,29 +75,29 @@ bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) for (uint32_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { LOG_DEBUG("%d", const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vector)))[rowIndex]); info << "{ rowIndex: " << rowIndex << ", nulls: " << - std::to_string(const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vector)))[rowIndex]); + std::to_string(const_cast((uint8_t*)(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vector)))[rowIndex]); switch (typeId) { case OMNI_SHORT: - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; case OMNI_INT: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_LONG: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DOUBLE: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DECIMAL64: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DECIMAL128: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_VARCHAR: { // unknown length for value vector, calculate later @@ -118,9 +118,16 @@ bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) valueAddress += vector->GetValueOffset(rowIndex); }*/ uint8_t *valueAddress = nullptr; - int32_t length = static_cast(vector)->GetValue(rowIndex, &valueAddress); + int32_t length = reinterpret_cast> *>(vector); std::string valueString(valueAddress, valueAddress + length); - info << ", value: " << valueString << " }, "; + uint32_t length = 0; + std::string_view value; + if (!vc->IsNull(rowIndex)) { + value = vc->GetValue(); + valueAddress = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); + } + info << ", value: " << value << " }, "; break; } default: @@ -314,7 +321,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Long_Cols) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important // for (uint64_t j = 0; j < 999; j++) { - VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000); + VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000, LongType()); OckTest_splitter_split(splitterId, vb); // } OckTest_splitter_stop(splitterId); @@ -323,7 +330,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Long_Cols) TEST_F(OckShuffleTest, Split_Fixed_Cols) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; // 4Byte + 8Byte + 8Byte + 3Byte + int32_t inputVecTypeIds[] = {OMNI_BOOLEAN, OMNI_SHORT, OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; // 4Byte + 8Byte + 8Byte + 3Byte gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); int partitionNum = 4; @@ -331,7 +338,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Cols) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very 
important // for (uint64_t j = 0; j < 999; j++) { - VectorBatch *vb = OckCreateVectorBatch_3fixedCols_withPid(partitionNum, 999); + VectorBatch *vb = OckCreateVectorBatch_5fixedCols_withPid(partitionNum, 999); OckTest_splitter_split(splitterId, vb); // } OckTest_splitter_stop(splitterId); @@ -340,7 +347,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Cols) TEST_F(OckShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; // 4 + 8 + 8 + 4 + 4 + int32_t inputVecTypeIds[] = {OMNI_BOOLEAN, OMNI_SHORT, OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; // 4 + 8 + 8 + 4 + 4 gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); int partitionNum = 1; @@ -399,7 +406,7 @@ TEST_F(OckShuffleTest, Split_Long_10WRows) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important for (uint64_t j = 0; j < 100; j++) { - VectorBatch *vb = OckCreateVectorBatch_1longCol_withPid(partitionNum, 10000); + VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000, LongType()); OckTest_splitter_split(splitterId, vb); } OckTest_splitter_stop(splitterId); @@ -458,7 +465,7 @@ TEST_F(OckShuffleTest, Split_VarChar_First) TEST_F(OckShuffleTest, Split_Dictionary) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DECIMAL64, OMNI_DECIMAL128}; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG}; int partitionNum = 4; gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); @@ -483,7 +490,7 @@ TEST_F(OckShuffleTest, Split_OMNI_DECIMAL128) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important for (uint64_t j = 0; j < 2; j++) { - VectorBatch *vb = OckCreateVectorBatch_1decimal128Col_withPid(partitionNum); + VectorBatch *vb = OckCreateVectorBatch_1decimal128Col_withPid(partitionNum, 999); OckTest_splitter_split(splitterId, vb); } OckTest_splitter_stop(splitterId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/tptest.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/tptest.cpp similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/test/tptest.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/tptest.cpp diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/CMakeLists.txt similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/CMakeLists.txt diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp similarity index 39% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp index 2b49ba28ffaaf79621de049fb59e39120cad5490..251aea490f144d386820074bf9101c92594022fc 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp @@ -10,7 +10,7 @@ using namespace omniruntime::vec; using namespace 
omniruntime::type; -void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) +/*void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) { for (int i = 0; i < dataTypeCount; ++i) { if (dataTypeIds[i] == OMNI_VARCHAR) { @@ -22,125 +22,39 @@ void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::ve } dataTypes.emplace_back(DataType(dataTypeIds[i])); } -} +}*/ -VectorBatch *OckCreateInputData(const int32_t numRows, const int32_t numCols, int32_t *inputTypeIds, int64_t *allData) +VectorBatch *OckCreateInputData(const DataType &types, int32_t rowCount, ...) { - auto *vecBatch = new VectorBatch(numCols, numRows); - std::vector inputTypes; - OckToVectorTypes(inputTypeIds, numCols, inputTypes); - vecBatch->NewVectors(VectorAllocator::GetGlobalAllocator(), inputTypes); - for (int i = 0; i < numCols; ++i) { - switch (inputTypeIds[i]) { - case OMNI_INT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_LONG: - ((LongVector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - break; - case OMNI_DOUBLE: - ((DoubleVector *)vecBatch->GetVector(i))->SetValues(0, (double *)allData[i], numRows); - break; - case OMNI_SHORT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_VARCHAR: - case OMNI_CHAR: { - for (int j = 0; j < numRows; ++j) { - int64_t addr = (reinterpret_cast(allData[i]))[j]; - std::string s(reinterpret_cast(addr)); - ((VarcharVector *)vecBatch->GetVector(i))->SetValue(j, (uint8_t *)(s.c_str()), s.length()); - } - break; - } - case OMNI_DECIMAL128: - ((Decimal128Vector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - break; - default: { - LogError("No such data type %d", inputTypeIds[i]); - } - } + int32_t typesCount = types.GetSize(); + auto *vecBatch = new VectorBatch(rowCount); + va_list args; + va_start(args, rowCount); + for (int32_t i = 0; i< typesCount; i++) { + dataTypePtr = type = types.GetType(i); + VectorBatch->Append(CreateVector(*type, rowCount, args)); } + va_end(args); return vecBatch; } -VarcharVector *OckCreateVarcharVector(VarcharDataType type, std::string *values, int32_t length) +BaseVector *CreateVector(DataType &dataType, int32_t rowCount, va_list &args) { - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - uint32_t width = type.GetWidth(); - VarcharVector *vector = std::make_unique(vecAllocator, length * width, length).release(); - uint32_t offset = 0; - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, reinterpret_cast(values[i].c_str()), values[i].length()); - bool isNull = values[i].empty() ? 
true : false; - vector->SetValueNull(i, isNull); - vector->SetValueOffset(i, offset); - offset += values[i].length(); - } - - if (length > 0) { - vector->SetValueOffset(values->size(), offset); - } - - std::stringstream offsetValue; - offsetValue << "{ "; - for (uint32_t index = 0; index < length; index++) { - offsetValue << vector->GetValueOffset(index) << ", "; - } - - offsetValue << vector->GetValueOffset(values->size()) << " }"; - - LOG_INFO("%s", offsetValue.str().c_str()); - - return vector; -} - -Decimal128Vector *OckCreateDecimal128Vector(Decimal128 *values, int32_t length) -{ - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - Decimal128Vector *vector = std::make_unique(vecAllocator, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, values[i]); - } - return vector; + return DYNAMIC_TYPE_DISPATCH(CreateFlatVector, dataType.GetId(), rowCount, args); } -Vector *OckCreateVector(DataType &vecType, int32_t rowCount, va_list &args) -{ - switch (vecType.GetId()) { - case OMNI_INT: - case OMNI_DATE32: - return OckCreateVector(va_arg(args, int32_t *), rowCount); - case OMNI_LONG: - case OMNI_DECIMAL64: - return OckCreateVector(va_arg(args, int64_t *), rowCount); - case OMNI_DOUBLE: - return OckCreateVector(va_arg(args, double *), rowCount); - case OMNI_BOOLEAN: - return OckCreateVector(va_arg(args, bool *), rowCount); - case OMNI_VARCHAR: - case OMNI_CHAR: - return OckCreateVarcharVector(static_cast(vecType), va_arg(args, std::string *), - rowCount); - case OMNI_DECIMAL128: - return OckCreateDecimal128Vector(va_arg(args, Decimal128 *), rowCount); - default: - std::cerr << "Unsupported type : " << vecType.GetId() << std::endl; - return nullptr; - } -} -DictionaryVector *OckCreateDictionaryVector(DataType &vecType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...) +BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ..) 
{ va_list args; va_start(args, idsCount); - Vector *dictionary = OckCreateVector(vecType, rowCount, args); + BaseVector *dictionary = CreateVector(dataType, rowCount, args); va_end(args); - auto vec = std::make_unique(dictionary, ids, idsCount).release(); - delete dictionary; - return vec; + return DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount); } +/* Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber) { VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); @@ -212,47 +126,37 @@ Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber) return nullptr; } } -} +}*/ -Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) +BaseVector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) { - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - switch (typeId) { + switch (typeId) { case OMNI_SHORT: { - auto *col = new ShortVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_NONE: { - auto *col = new LongVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_INT: case OMNI_DATE32: { - auto *col = new IntVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_LONG: case OMNI_DECIMAL64: { - auto *col = new LongVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_DOUBLE: { - auto *col = new DoubleVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_BOOLEAN: { - auto *col = new BooleanVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_DECIMAL128: { - auto *col = new Decimal128Vector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_VARCHAR: case OMNI_CHAR: { - VarcharDataType charType = (VarcharDataType &)typeId; - auto *col = new VarcharVector(vecAllocator, charType.GetWidth() * rowNumber, rowNumber); - return col; + return new Vector>(rowNumber); } default: { LogError("No such %d type support", typeId); @@ -261,15 +165,15 @@ Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) } } -VectorBatch *OckCreateVectorBatch(DataTypes &types, int32_t rowCount, ...) +VectorBatch *OckCreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) 
{ int32_t typesCount = types.GetSize(); - VectorBatch *vectorBatch = std::make_unique(typesCount).release(); + auto *vectorBatch = new vecBatch(rowCount); va_list args; va_start(args, rowCount); for (int32_t i = 0; i < typesCount; i++) { - DataType type = types.Get()[i]; - vectorBatch->SetVector(i, OckCreateVector(type, rowCount, args)); + dataTypePtr type = types.GetType(i); + vectorBatch->Append(OckCreateVector(*type, rowCount, args)); } va_end(args); return vectorBatch; @@ -286,23 +190,46 @@ VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::strin { // gen vectorBatch const int32_t numCols = 2; - auto inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - + DataTypes inputTypes(std::vector)({ IntType(), VarcharType()}); const int32_t numRows = 1; auto *col1 = new int32_t[numRows]; col1[0] = pid; - auto *col2 = new int64_t[numRows]; - auto *strTmp = new std::string(std::move(inputString)); - col2[0] = (int64_t)(strTmp->c_str()); + auto *col2 = new std::string[numRows]; + col2[0] = std::move(inputString); + VectorBatch *in = OckCreateInputData(inputTypes, numCols, col1, col2); + delete[] col1; + delete[] col2; + return in; +} - int64_t allData[numCols] = {reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); +VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) +{ + int partitionNum = parNum; + const int32_t numCols = 5; + DataTypes inputTypes(std::vector)({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() }); + const int32_t numRows = rowNum; + auto *col0 = new int32_t[numRows]; + auto *col1 = new std::string[numRows]; + auto *col2 = new std::string[numRows]; + auto *col3 = new std::string[numRows]; + auto *col4 = new std::string[numRows]; + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_START_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_START_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_START_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_START_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); + } + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); + delete[] col0; delete[] col1; delete[] col2; - delete strTmp; + delete[] col3; + delete[] col4; return in; } @@ -316,229 +243,104 @@ VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::strin VectorBatch *OckCreateVectorBatch_4col_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 6; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; - inputTypes[4] = OMNI_VARCHAR; - inputTypes[5] = OMNI_SHORT; - + DataTypes inputTypes(std::vector)({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() }); + const int32_t numRows = rowNum; auto *col0 = new int32_t[numRows]; auto *col1 = new int32_t[numRows]; auto *col2 = new int64_t[numRows]; auto *col3 = new double[numRows]; - auto *col4 = new int64_t[numRows]; - auto *col5 = new int16_t[numRows]; + auto *col4 = new std::string[numRows]; std::string startStr = "_START_"; std::string endStr = "_END_"; - - std::vector string_cache_test_; + std::vector string_cache_test_; for (int 
i = 0; i < numRows; i++) { col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; col2[i] = i + 1; col3[i] = i + 1; - auto *strTmp = new std::string(startStr + std::to_string(i + 1) + endStr); - string_cache_test_.push_back(strTmp); - col4[i] = (int64_t)((*strTmp).c_str()); - col5[i] = i + 1; + std::string strTmp = std::string(startStr + to_string(i + 1) + endStr); + col4[i] = std::move(strTmp); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4), - reinterpret_cast(col5)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (int p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // 释放内存 - } return in; } -VectorBatch *OckCreateVectorBatch_1longCol_withPid(int parNum, int rowNum) -{ - int partitionNum = parNum; - const int32_t numCols = 2; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_LONG; - - const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - for (int i = 0; i < numRows; i++) { - col0[i] = (i + 1) % partitionNum; - col1[i] = i + 1; - } - - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - for (int i = 0; i < 2; i++) { - delete (int64_t *)allData[i]; // 释放内存 - } - return in; -} - -VectorBatch *OckCreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) -{ - const int32_t numCols = 3; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_INT; +VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) { + DataTypes inputTypes(std::vector({ IntType(), VarcharType(), IntType() })); const int32_t numRows = 1; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - auto *col2 = new int32_t[numRows]; + auto* col0 = new int32_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new int32_t[numRows]; col0[0] = pid; - auto *strTmp = new std::string(strVar); - col1[0] = (int64_t)(strTmp->c_str()); + col1[0] = std::move(strVar); col2[0] = intVar; - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); delete[] col0; delete[] col1; delete[] col2; - delete strTmp; return in; } -VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) +VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int rowNum, dataTypePtr fixColType) { int partitionNum = parNum; - const int32_t numCols = 5; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_VARCHAR; - inputTypes[3] = OMNI_VARCHAR; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), std::move(fixColType) })); const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - auto *col2 = new int64_t[numRows]; - auto *col3 = new int64_t[numRows]; - 
auto *col4 = new int64_t[numRows]; - - std::vector string_cache_test_; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; for (int i = 0; i < numRows; i++) { col0[i] = (i + 1) % partitionNum; - auto *strTmp1 = new std::string("Col1_START_" + std::to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - auto *strTmp2 = new std::string("Col2_START_" + std::to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - auto *strTmp3 = new std::string("Col3_START_" + std::to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - auto *strTmp4 = new std::string("Col4_START_" + std::to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); - } - - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; - delete[] col0; - delete[] col1; - delete[] col2; - delete[] col3; - delete[] col4; - - for (int p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // 释放内存 - } - return in; -} - -VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum) -{ - int partitionNum = parNum; - - // gen vectorBatch - const int32_t numCols = 1; - auto *inputTypes = new int32_t[numCols]; - // inputTypes[0] = OMNI_INT; - inputTypes[0] = OMNI_LONG; - - const uint32_t numRows = rowNum; - - std::cout << "gen row " << numRows << std::endl; - // auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - for (int i = 0; i < numRows; i++) { - // col0[i] = 0; // i % partitionNum; col1[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col1)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; - // delete[] col0; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; delete[] col1; - return in; + return in; } -VectorBatch *OckCreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum) +VectorBatch *OckCreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - // gen vectorBatch - const int32_t numCols = 4; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; + DataTypes inputTypes( + std::vector({ IntType(), BooleanType(), ShortType(), IntType(), LongType(), DoubleType() })); const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int32_t[numRows]; - auto *col2 = new int64_t[numRows]; - auto *col3 = new double[numRows]; + auto* col0 = new int32_t[numRows]; + auto* col1 = new bool[numRows]; + auto* col2 = new int16_t[numRows]; + auto* col3 = new int32_t[numRows]; + auto* col4 = new int64_t[numRows]; + auto* col5 = new double[numRows]; for (int i = 0; i < numRows; i++) { col0[i] = i % partitionNum; - col1[i] = i + 1; + col1[i] = (i % 2) == 0 ? 
true : false; col2[i] = i + 1; col3[i] = i + 1; + col4[i] = i + 1; + col5[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4, col5); delete[] col0; delete[] col1; delete[] col2; delete[] col3; - return in; + delete[] col4; + delete[] col5; + return in; } VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum) @@ -547,121 +349,121 @@ VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum) // construct input data const int32_t dataSize = 6; // prepare data - int32_t data0[dataSize] = {111, 112, 113, 114, 115, 116}; - int64_t data1[dataSize] = {221, 222, 223, 224, 225, 226}; - int64_t data2[dataSize] = {111, 222, 333, 444, 555, 666}; - Decimal128 data3[dataSize] = {Decimal128(0, 1), Decimal128(0, 2), Decimal128(0, 3), Decimal128(0, 4), Decimal128(0, 5), Decimal128(0, 6)}; - void *datas[4] = {data0, data1, data2, data3}; + auto *col0 = new int32_t[dataSize]; + for (int32_t i = 0; i< dataSize; i++) { + col0[i] = (i + 1) % partitionNum; + } + int32_t col1[dataSize] = {111, 112, 113, 114, 115, 116}; + int64_t col2[dataSize] = {221, 222, 223, 224, 225, 226}; + void *datas[2] = {col1, col2}; + DataTypes sourceTypes(std::vector({ IntType(), LongType() })); + int32_t ids[] = {0, 1, 2, 3, 4, 5}; - DataTypes sourceTypes(std::vector({ IntDataType(), LongDataType(), Decimal64DataType(7, 2), Decimal128DataType(38, 2)})); + VectorBatch *vectorBatch = new VectorBatch(dataSize); + auto Vec0 = CreateVector(dataSize, col0); + vectorBatch->Append(Vec0); + auto dicVec0 = CreateDictionaryVector(*sourceTypes.GetType(0), dataSize, ids, dataSize, datas[0]); + auto dicVec1 = CreateDictionaryVector(*sourceTypes.GetType(1), dataSize, ids, dataSize, datas[1]); + vectorBatch->Append(dicVec0); + vectorBatch->Append(dicVec1); - int32_t ids[] = {0, 1, 2, 3, 4, 5}; - auto vectorBatch = new VectorBatch(5, dataSize); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto intVectorTmp = new IntVector(allocator, 6); - for (int i = 0; i < intVectorTmp->GetSize(); i++) { - intVectorTmp->SetValue(i, (i + 1) % partitionNum); - } - for (int32_t i = 0; i < 5; i++) { - if (i == 0) { - vectorBatch->SetVector(i, intVectorTmp); - } else { - omniruntime::vec::DataType dataType = sourceTypes.Get()[i - 1]; - vectorBatch->SetVector(i, OckCreateDictionaryVector(dataType, dataSize, ids, dataSize, datas[i - 1])); - } - } + delete[] col0; return vectorBatch; } VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum) { - int32_t ROW_PER_VEC_BATCH = 999; - auto decimal128InputVec = OckbuildVector(Decimal128DataType(38, 2), ROW_PER_VEC_BATCH); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto *intVectorPid = new IntVector(allocator, ROW_PER_VEC_BATCH); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i + 1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = Decimal128(0, 1); } - auto *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - 
vecBatch->SetVector(1, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch *OckCreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = OckbuildVector(Decimal64DataType(7, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch *OckCreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = OckbuildVector(Decimal64DataType(7, 2), rowNum); - auto decimal128InputVec = OckbuildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + auto *col2 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; + col2[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(3); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - vecBatch->SetVector(2, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); + delete[] col0; + delete[] col1; + delete[] col2; + return in; } VectorBatch *OckCreateVectorBatch_someNullRow_vectorBatch() { const int32_t numRows = 6; - int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; - int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; - double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; - std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - - auto vec0 = OckCreateVector(data1, numRows); - auto vec1 = OckCreateVector(data2, numRows); - auto vec2 = OckCreateVector(data3, numRows); - auto vec3 = OckCreateVarcharVector(VarcharDataType(varcharType), data4, numRows); - for (int i = 0; i < numRows; i = i + 2) { - vec0->SetValueNull(i, false); - vec1->SetValueNull(i, false); - vec2->SetValueNull(i, false); + const int32_t numCols = 6; + bool data0[numRows] = {true, false, true, false, true, false}; + int16_t data1[numRows] = {0, 1, 2, 3, 4, 6}; + int32_t data2[numRows] = {0, 1, 2, 0, 1, 2}; + int64_t data3[numRows] = {0, 1, 2, 3, 4, 5}; + double data4[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + std::string data5[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; + + DataTypes inputTypes( + std::vector({ BooleanType(), ShortType(), IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = 
CreateVectorBatch(inputTypes, numRows, data0, data1, data2, data3, data4, data5); + for (int32_t i = 0; i < numCols; i++) { + for (int32_t j = 0; j < numRows; j = j + 2) { + vecBatch->Get(i)->SetNull(j); + } } - auto *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } VectorBatch *OckCreateVectorBatch_someNullCol_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 4; int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = OckCreateVector(data1, numRows); - auto vec1 = OckCreateVector(data2, numRows); - auto vec2 = OckCreateVector(data3, numRows); - auto vec3 = OckCreateVarcharVector(VarcharDataType(varcharType), data4, numRows); - for (int i = 0; i < numRows; i = i + 1) { - vec1->SetValueNull(i); - vec3->SetValueNull(i); + DataTypes inputTypes(std::vector({ IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data1, data2, data3, data4); + for (int32_t i = 0; i < numCols; i = i + 2) { + for (int32_t j = 0; j < numRows; j++) { + vecBatch->Get(i)->SetNull(j); + } } - auto *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h similarity index 52% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h index 9695a5ad6f1015ed230e426c20b1981b98891a84..6ffb74492d39dd81c17c6b7c21fb1a9b557c3085 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h @@ -11,7 +11,7 @@ #include #include #include - +#include #include "../../src/jni/concurrent_map.h" #define private public static const int varcharType = 5; @@ -22,29 +22,29 @@ static ock::dopspark::ConcurrentMap> static std::string Ocks_shuffle_tests_dir = "/tmp/OckshuffleTests"; -VectorBatch *OckCreateInputData(const int32_t numRows, const int32_t numCols, int32_t *inputTypeIds, int64_t *allData); +std::unique_ptr CreateVector(DataType &dataType, int32_t rowCount, va_list &args); + +VectorBatch *OckCreateInputData(const DataTypes &types, int32_t rowCount, ...); -Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber); +VectorBatch *OckCreateVectorBatch(const DataTypes &types, int32_t rowCount, ...); -Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber); +BaseVector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber); + +VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum); VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::string &inputChar); VectorBatch *OckCreateVectorBatch_4col_withPid(int parNum, int rowNum); -VectorBatch *OckCreateVectorBatch_1longCol_withPid(int parNum, int rowNum); - VectorBatch *OckCreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar); -VectorBatch 
*OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum); - -VectorBatch *OckCreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum); +VectorBatch *OckCreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum); -VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum); +VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum, DataTypePtr fixColType); VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum); -VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum); +VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum); VectorBatch *OckCreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum); @@ -67,6 +67,53 @@ void OckTest_splitter_stop(long splitter_id); void OckTest_splitter_close(long splitter_id); +template BaseVector *CreateVector(int32_t length, T *values) +{ + std::unique_ptr> vector = std::make_unique>(length); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, values[i]); + } + return vector; +} + +template +BaseVector *CreateFlatVector(int32_t length, va_list &args) +{ + using namespace omniruntime::type; + using T = typename NativeType::type; + using VarcharVector = Vector>; + if constexpr (std::is_same_v) { + VarcharVector *vector = new VarcharVector(length); + std::string *str = va_arg(args, std::string *); + for (int32_t i = 0; i < length; i++) { + std::string_view value(str[i].data(), str[i].length()); + vector->SetValue(i, value); + } + return vector; + } else { + Vector *vector = new Vector(length); + T *value = va_arg(args, T *); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, value[i]); + } + return vector; + } +} + +template +BaseVector *CreateDictionary(BaseVector *vector, int32_t *ids, int32_t size) +{ + using T = typename NativeType::type; + if constexpr (std::is_same_v) { + return VectorHelper::CreateStringDictionary(ids, size, + reinterpret_cast> *>(vector)); + } else { + return VectorHelper::CreateDictionary(ids, size, reinterpret_cast *>(vector)); + } +} + + + template T *OckCreateVector(V *values, int32_t length) { VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..b2fdb093d1a890acfe16eb154c522d1af04baf0e --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml @@ -0,0 +1,122 @@ + + + 4.0.0 + + com.huawei.ock + omniop-spark-extension-ock + 23.0.0 + + + cpp/ + cpp/build/releases/ + FALSE + 0.6.1 + + + ock-omniop-shuffle-manager + jar + Huawei Open Computing Kit for Spark, shuffle manager + 23.0.0 + + + + ${project.artifactId}-${project.version}-for-${input.version} + + + ${cpp.build.dir} + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + all + + + + + compile + testCompile + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 8 + 8 + true + + -Xlint:all + + + + + exec-maven-plugin + org.codehaus.mojo + 3.0.0 + + + Build CPP + generate-resources + + exec + + + bash + + ${cpp.dir}/build.sh + ${plugin.cpp.test} + + + + + + + org.xolstice.maven.plugins + 
protobuf-maven-plugin + ${protobuf.maven.version} + + ${project.basedir}/../cpp/src/proto + + + + + compile + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.plugin.version} + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java similarity index 96% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java index ec294bdbf2208361846b4576ba0559abb9cfabc2..462ad9d105a54374bc867a9d83e45133fc238332 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java @@ -150,8 +150,18 @@ public class OckShuffleJniReader { nativeCopyVecDataInVB(nativeReader, dstVec.getNativeVector(), colIndex); } + /** + * close reader. + * + */ + public void doClose() { + close(nativeReader); + } + private native long make(int[] typeIds); + private native long close(long readerId); + private native int nativeGetVectorBatch(long readerId, long vbDataAddr, int capacity, int maxRow, int maxDataSize, Long rowCnt); diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala rename to 
omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala similarity index 95% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala index dc7e081555dfed6646beed6b85fc1f8356b8aa86..d751679e5af9924702ae30127251621f9e93a7fc 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala @@ -51,6 +51,9 @@ class OckColumnarShuffleBufferIterator[T]( NativeShuffle.destroyMapTaskInfo(mapTaskToHostInfo.getNativeObjHandle) mapTaskToHostInfo.setNativeObjHandle(0) } + blobMap.values.foreach(reader => { + reader.doClose() + }) } private[this] def throwFetchException(fetchError: FetchError): Unit = { @@ -84,7 +87,7 @@ class OckColumnarShuffleBufferIterator[T]( // create buffers, or blobIds // use bagName, numBuffers and bufferSize to create buffers in low level if (totalFetchNum != 0) { - NativeShuffle.shuffleStreamReadStart(sequenceId) + NativeShuffle.shuffleStreamReadStart(sequenceId, endPartition) hasBlob = true } diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala similarity index 94% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala index 3457f0da62f4db1ae14e614d8925ff2089e1d256..f7f07fbb2483b588e4e1cef017a378854beaad60 100644 --- 
a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala @@ -4,7 +4,7 @@ package org.apache.spark.shuffle.ock -import com.huawei.ock.ucache.common.exception.ApplicationException +import com.huawei.ock.common.exception.ApplicationException import com.huawei.ock.ucache.shuffle.NativeShuffle import org.apache.spark._ import org.apache.spark.executor.TempShuffleReadMetrics @@ -32,14 +32,14 @@ class OckColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager var appId = "" var listenFlg: Boolean = false var isOckBroadcast: Boolean = ockConf.isOckBroadcast - var heartBeatFlag = false + @volatile var heartBeatFlag: AtomicBoolean = new AtomicBoolean(false) val applicationDefaultAttemptId = "1"; if (ockConf.excludeUnavailableNodes && ockConf.appId == "driver") { OCKScheduler.waitAndBlacklistUnavailableNode(conf) } - OCKFunctions.shuffleInitialize(ockConf, isOckBroadcast) + OCKFunctions.shuffleInitialize(ockConf) val isShuffleCompress: Boolean = conf.get(config.SHUFFLE_COMPRESS) val compressCodec: String = conf.get(IO_COMPRESSION_CODEC); OCKFunctions.setShuffleCompress(OckColumnarShuffleManager.isCompress(conf), compressCodec) @@ -63,8 +63,7 @@ class OckColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager tokenCode = OckColumnarShuffleManager.registerShuffle(shuffleId, dependency.partitioner.numPartitions, conf, ockConf) } - if (!heartBeatFlag && ockConf.appId == "driver") { - heartBeatFlag = true + if (ockConf.appId == "driver" && !heartBeatFlag.getAndSet(true)) { OCKFunctions.tryStartHeartBeat(this, appId) } @@ -187,7 +186,6 @@ private[spark] object OckColumnarShuffleManager extends Logging { logWarning("failed to change externalShuffleServiceEnabled in block manager," + " maybe ockd could not be able to recover in shuffle process") } - conf.set(config.SHUFFLE_SERVICE_ENABLED, true) } // generate token code. Need 32bytes. 
OCKFunctions.getToken(ockConf.isIsolated) diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala similarity index 94% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala index e7aaf0fdf7a21abf1737d5280a5d83733cf9d416..83264792d8eca965c3e1069bdb4e3c2dfec032b5 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala @@ -7,6 +7,7 @@ package org.apache.spark.shuffle.ock import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs import com.huawei.boostkit.spark.vectorized.SplitResult import com.huawei.ock.spark.jni.OckShuffleJniWriter +import com.huawei.ock.ucache.shuffle.NativeShuffle import nova.hetu.omniruntime.vector.VecBatch import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus @@ -140,6 +141,7 @@ class OckColumnarShuffleWriter[K, V]( } else { stopping = true if (success) { + NativeShuffle.shuffleStageSetShuffleId("Spark_"+applicationId, context.stageId(), handle.shuffleId) Option(mapStatus) } else { None diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..345504ed5fa1b2fd353f754cea7a3a2efe9492c0 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml @@ -0,0 +1,138 @@ + + + 4.0.0 + + + 3.1.2 + 2.12.10 + 2.12 + 3.2.3 + org.apache.spark + spark-3.1 + 3.2.0 + 3.1.1 + 23.0.0 + + + com.huawei.ock + ock-omniop-tuning + jar + Huawei Open Computing Kit for Spark, BoostTuning for OmniOperator + 23.0.0 + + + + org.scala-lang + scala-library + ${scala.version} + provided + + + ${spark.groupId} + spark-core_${scala.compat.version} + ${spark.version} + provided + + + ${spark.groupId} + spark-catalyst_${scala.compat.version} + ${spark.version} + provided + + + ${spark.groupId} + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + com.huawei.ock + ock-adaptive-tuning + ${global.version} + + + com.huawei.ock + ock-tuning-sdk + ${global.version} + + + com.huawei.ock + ock-shuffle-sdk + ${global.version} + + + com.huawei.boostkit + boostkit-omniop-bindings + 1.3.0 + + + com.huawei.kunpeng + boostkit-omniop-spark + 3.1.1-1.3.0 + + + org.scalatest + scalatest_${scala.compat.version} + ${scalaTest.version} + test + + + + + 
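
The OckColumnarShuffleManager hunk above replaces the plain var heartBeatFlag with an AtomicBoolean and guards the driver heartbeat with !heartBeatFlag.getAndSet(true). The previous code was a check-then-act sequence (read the flag, then assign it) that two concurrent registerShuffle calls could both pass; getAndSet folds the test and the update into a single atomic step. The sketch below is a minimal, self-contained illustration of that pattern only; HeartBeatGuard and its startHeartBeat placeholder are hypothetical names, not the real OCKFunctions API.

import java.util.concurrent.atomic.AtomicBoolean

// Minimal sketch of the start-exactly-once pattern; startHeartBeat is a
// placeholder body, not the real OCKFunctions.tryStartHeartBeat call.
object HeartBeatGuard {
  private val started = new AtomicBoolean(false)

  def ensureStarted(appId: String): Unit = {
    // getAndSet is atomic: only the first caller observes false, so racing
    // threads cannot start the heartbeat twice.
    if (!started.getAndSet(true)) {
      startHeartBeat(appId)
    }
  }

  private def startHeartBeat(appId: String): Unit =
    println(s"heartbeat started for $appId")
}

Because the AtomicBoolean reference itself never changes, the atomicity comes entirely from getAndSet; the @volatile annotation on the field is belt-and-braces rather than required.
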
${project.artifactId}-${project.version}-for-${input.version} + src/main/scala + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + all + + + + + compile + testCompile + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 8 + 8 + true + + -Xlint:all + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.plugin.version} + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala new file mode 100644 index 0000000000000000000000000000000000000000..e63ef56726f91bcb5abeaecbf5c5c865e1c6280c --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.execution.adaptive.ock.rule._ + +class OmniOpBoostTuningExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectQueryStagePrepRule(_ => BoostTuningQueryStagePrepRule()) + extensions.injectColumnar(_ => OmniOpBoostTuningColumnarRule( + OmniOpBoostTuningPreColumnarRule(), OmniOpBoostTuningPostColumnarRule())) + SparkContext.getActive.get.addSparkListener(new BoostTuningListener()) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala new file mode 100644 index 0000000000000000000000000000000000000000..42d415bb685823e84e311afef161b43edf6622fa --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.common + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.SparkEnv + +object OmniOpDefine { + final val COLUMNAR_SHUFFLE_MANAGER_DEFINE = "org.apache.spark.shuffle.sort.ColumnarShuffleManager" + + final val COLUMNAR_SORT_SPILL_ROW_THRESHOLD = "spark.omni.sql.columnar.sortSpill.rowThreshold" + final val COLUMNAR_SORT_SPILL_ROW_BASED_ENABLED = "spark.omni.sql.columnar.sortSpill.enabled" +} + +object OmniOCKShuffleDefine { + final val OCK_COLUMNAR_SHUFFLE_MANAGER_DEFINE = "org.apache.spark.shuffle.ock.OckColumnarShuffleManager" +} + +object OmniRuntimeConfiguration { + val enableColumnarShuffle: Boolean = ColumnarPluginConfig.getSessionConf.enableColumnarShuffle + val OMNI_SPILL_ROWS: Long = SparkEnv.get.conf.getLong(OmniOpDefine.COLUMNAR_SORT_SPILL_ROW_THRESHOLD, Integer.MAX_VALUE) + val OMNI_SPILL_ROW_ENABLED: Boolean = SparkEnv.get.conf.getBoolean(OmniOpDefine.COLUMNAR_SORT_SPILL_ROW_BASED_ENABLED, defaultValue = true) +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..c0b00da7628349844890ed872fad98828de8e1e9 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala @@ -0,0 +1,207 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer + +import nova.hetu.omniruntime.`type`.DataType + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger._ +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningUtil._ +import org.apache.spark.sql.execution.adaptive.ock.exchange.estimator._ +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.{MapOutputStatistics, ShuffleDependency} +import org.apache.spark.util.MutablePair + +import scala.concurrent.Future + +case class BoostTuningColumnarShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS, + @transient context: PartitionContext) extends BoostTuningShuffleExchangeLike{ + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"), + "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_split"), + "spillTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle spill time"), + "compressTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_compress"), + "avgReadBatchNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg read batch num rows"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numOutputRows" -> SQLMetrics + .createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics + + override def nodeName: String = "BoostTuningOmniColumnarShuffleExchange" + + override def getContext: PartitionContext = context + + override def getDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = boostTuningColumnarShuffleDependency + + override def getUpStreamDataSize: Long = collectUpStreamInputDataSize(this.child) + + override def getPartitionEstimators: Seq[PartitionEstimator] = estimators + + @transient val helper: BoostTuningShuffleExchangeHelper = + new BoostTuningColumnarShuffleExchangeHelper(this, sparkContext) + + @transient lazy val estimators: Seq[PartitionEstimator] = Seq( + UpStreamPartitionEstimator(), + ColumnarSamplePartitionEstimator(helper.executionMem)) ++ Seq( + SinglePartitionEstimator(), + ColumnarElementsForceSpillPartitionEstimator() + ) 
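
The estimators declared just above follow a uniform contract that the estimator sources later in this patch make explicit: each PartitionEstimator receives the exchange and returns Option[Int], either proposing a partition count or abstaining with None when its signal does not apply (sampling disabled, spill threshold unset, and so on). The toy model below mirrors only that shape; the estimator names and the merge policy are hypothetical illustrations, and the real combination logic lives in BoostTuningShuffleExchangeHelper, which is not part of this diff.

// Toy model of the estimator contract: propose a partition count or abstain.
trait ToyPartitionEstimator {
  def estimate(upstreamBytes: Long): Option[Int]
}

// Hypothetical size-based estimator: aim for a fixed number of bytes per partition.
case class BytesPerPartitionEstimator(targetBytesPerPartition: Long) extends ToyPartitionEstimator {
  def estimate(upstreamBytes: Long): Option[Int] =
    if (upstreamBytes <= 0) None
    else Some(math.max(1, (upstreamBytes / targetBytesPerPartition).toInt))
}

// Hypothetical single-partition estimator: tiny inputs collapse to one partition.
case object TinyInputEstimator extends ToyPartitionEstimator {
  def estimate(upstreamBytes: Long): Option[Int] =
    if (upstreamBytes < (1L << 20)) Some(1) else None
}

object ToyPartitionDecision {
  // Illustrative merge policy only: take the largest proposal, fall back to a default.
  def decide(estimators: Seq[ToyPartitionEstimator], upstreamBytes: Long, default: Int): Int = {
    val candidates = estimators.flatMap(_.estimate(upstreamBytes))
    if (candidates.isEmpty) default else candidates.max
  }
}

For example, ToyPartitionDecision.decide(Seq(BytesPerPartitionEstimator(64L << 20), TinyInputEstimator), upstreamBytes = 2L << 30, default = 200) yields 32: the size-based estimator proposes 2 GiB / 64 MiB = 32 partitions and the tiny-input estimator abstains.
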
+ + override def supportsColumnar: Boolean = true + + val serializer: Serializer = new ColumnarBatchSerializer( + longMetric("avgReadBatchNumRows"), + longMetric("numOutputRows")) + + @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar() + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputColumnarRDD.getNumPartitions == 0) { + context.setSelfAndDepPartitionNum(outputPartitioning.numPartitions) + Future.successful(null) + } else { + omniAdaptivePartitionWithMapOutputStatistics() + } + } + + private def omniAdaptivePartitionWithMapOutputStatistics(): Future[MapOutputStatistics] = { + helper.cachedSubmitMapStage() match { + case Some(f) => return f + case _ => + } + + helper.onlineSubmitMapStage() match { + case f: Future[MapOutputStatistics] => f + case _ => Future.failed(null) + } + } + + override def numMappers: Int = boostTuningColumnarShuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = boostTuningColumnarShuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = { + throw new IllegalArgumentException("Failed to getShuffleRDD, exec should use ColumnarBatch but not InternalRow") + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + + @transient + lazy val boostTuningColumnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val partitionInitTime = System.currentTimeMillis() + val newOutputPartitioning = helper.replacePartitionWithNewNum() + val partitionReadyTime = System.currentTimeMillis() + val dep = ColumnarShuffleExchangeExec.prepareShuffleDependency( + inputColumnarRDD, + child.output, + newOutputPartitioning, + serializer, + writeMetrics, + longMetric("dataSize"), + longMetric("bytesSpilled"), + longMetric("numInputRows"), + longMetric("splitTime"), + longMetric("spillTime")) + val dependencyReadyTime = System.currentTimeMillis() + TLogInfo(s"BoostTuningShuffleExchange $id input partition ${inputColumnarRDD.getNumPartitions}" + + s" modify ${if (helper.isAdaptive) "adaptive" else "global"}" + + s" partitionNum ${outputPartitioning.numPartitions} -> ${newOutputPartitioning.numPartitions}" + + s" partition modify cost ${partitionReadyTime - partitionInitTime} ms" + + s" dependency prepare cost ${dependencyReadyTime - partitionReadyTime} ms") + dep + } + + var cachedShuffleRDD: ShuffledColumnarRDD = _ + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } + + def buildCheck(): Unit = { + val inputTypes = new Array[DataType](child.output.size) + child.output.zipWithIndex.foreach { + case (attr, i) => + inputTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + + outputPartitioning match { + case HashPartitioning(expressions, numPartitions) => + val genHashExpressionFunc = ColumnarShuffleExchangeExec.genHashExpr() + val hashJSonExpressions = genHashExpressionFunc(expressions, numPartitions, ColumnarShuffleExchangeExec.defaultMm3HashSeed, child.output) + if (!isSimpleColumn(hashJSonExpressions)) { + checkOmniJsonWhiteList("", Array(hashJSonExpressions)) + } + case _ => + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + if (cachedShuffleRDD == null) { + 
cachedShuffleRDD = new ShuffledColumnarRDD(boostTuningColumnarShuffleDependency, readMetrics) + } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs"), + longMetric("bypassVecBatchs")) + } + } else { + cachedShuffleRDD + } + } + + protected def withNewChildInternal(newChild: SparkPlan): BoostTuningColumnarShuffleExchangeExec = { + copy(child = newChild) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..bb3838d723e4f0800b702727ec1aa8c5605f375a --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common._ +import org.apache.spark.sql.execution.adaptive.ock.memory._ + +import java.util + +class BoostTuningColumnarShuffleExchangeHelper(exchange: BoostTuningShuffleExchangeLike, sparkContext: SparkContext) + extends BoostTuningShuffleExchangeHelper(exchange, sparkContext) { + + override val executionMem: Long = shuffleManager match { + case OCKBoostShuffleDefine.OCK_SHUFFLE_MANAGER_DEFINE => + BoostShuffleExecutionModel().apply() + case OmniOpDefine.COLUMNAR_SHUFFLE_MANAGER_DEFINE => + ColumnarExecutionModel().apply() + case OmniOCKShuffleDefine.OCK_COLUMNAR_SHUFFLE_MANAGER_DEFINE => + ColumnarExecutionModel().apply() + case _ => + OriginExecutionModel().apply() + } + + override protected def fillInput(input: util.LinkedHashMap[String, String]): Unit = { + input.put("executionSize", executionMem.toString) + input.put("upstreamDataSize", exchange.getUpStreamDataSize.toString) + input.put("partitionRatio", initPartitionRatio.toString) + var spillThreshold = if (OMNI_SPILL_ROW_ENABLED) { + Math.min(OMNI_SPILL_ROWS, numElementsForceSpillThreshold) + } else { + numElementsForceSpillThreshold + } + if (spillThreshold == Integer.MAX_VALUE) { + spillThreshold = -1 + } + + input.put("elementSpillThreshold", spillThreshold.toString) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala new file mode 100644 index 
0000000000000000000000000000000000000000..984537352a026343a664ea87e7a1aa2806b0b0a8 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange.estimator + +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.exchange._ +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + +case class ColumnarElementsForceSpillPartitionEstimator() extends PartitionEstimator { + + override def estimatorType: EstimatorType = ElementNumBased + + override def apply(exchange: ShuffleExchangeLike): Option[Int] = { + if (!sampleEnabled) { + return None + } + + if (!OMNI_SPILL_ROW_ENABLED && numElementsForceSpillThreshold == Integer.MAX_VALUE) { + return None + } + + val spillMinThreshold = if (OMNI_SPILL_ROW_ENABLED) { + Math.min(OMNI_SPILL_ROWS, numElementsForceSpillThreshold) + } else { + numElementsForceSpillThreshold + } + + exchange match { + case ex: BoostTuningColumnarShuffleExchangeExec => + val rowCount = ex.inputColumnarRDD + .sample(withReplacement = false, sampleRDDFraction) + .map(cb => cb.numRows()).first() + Some((initPartitionRatio * rowCount / spillMinThreshold).toInt) + case _ => + None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala new file mode 100644 index 0000000000000000000000000000000000000000..c336ffee383a0bcd98774851d36af96cf4b65d0a --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange.estimator + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil + +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.exchange.BoostTuningColumnarShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + +case class ColumnarSamplePartitionEstimator(executionMem: Long) extends PartitionEstimator { + + override def estimatorType: EstimatorType = DataSizeBased + + override def apply(exchange: ShuffleExchangeLike): Option[Int] = { + if (!sampleEnabled) { + return None + } + + exchange match { + case ex: BoostTuningColumnarShuffleExchangeExec => + val inputPartitionNum = ex.inputColumnarRDD.getNumPartitions + val sampleRDD = ex.inputColumnarRDD + .sample(withReplacement = false, sampleRDDFraction) + .map(cb => OmniAdaptorUtil.transColBatchToOmniVecs(cb).map(_.getCapacityInBytes).sum) + Some(SamplePartitionEstimator(executionMem).sampleAndGenPartitionNum(ex, inputPartitionNum, sampleRDD)) + case _ => + None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala new file mode 100644 index 0000000000000000000000000000000000000000..e28db0bf9e2539b8b8aacb71f28856bc582832bd --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.memory + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ + +case class ColumnarExecutionModel() extends ExecutionModel { + override def apply(): Long = { + val systemMem = executorMemory + val executorCores = SparkEnv.get.conf.get(config.EXECUTOR_CORES).toLong + val reservedMem = SparkEnv.get.conf.getLong("spark.testing.reservedMemory", 300 * 1024 * 1024) + val usableMem = systemMem - reservedMem + val shuffleMemFraction = SparkEnv.get.conf.get(config.MEMORY_FRACTION) * + (1 - SparkEnv.get.conf.get(config.MEMORY_STORAGE_FRACTION)) + val offHeapMem = if (offHeapEnabled) { + offHeapSize + } else { + 0 + } + val finalMem = ((usableMem * shuffleMemFraction + offHeapMem) / executorCores).toLong + TLogDebug(s"ExecutorMemory is $systemMem reserved $reservedMem offHeapMem is $offHeapMem" + + s" shuffleMemFraction is $shuffleMemFraction, execution memory of executor is $finalMem") + finalMem + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..eb38f770907dde2e9882bbd73a0985978bb29aa6 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala @@ -0,0 +1,233 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.reader + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.ock.exchange.BoostTuningColumnarShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import scala.collection.mutable.ArrayBuffer + + +/** + * A wrapper of shuffle query stage, which follows the given partition arrangement. + * + * @param child It is usually `ShuffleQueryStageExec`, but can be the shuffle exchange + * node during canonicalization. + * @param partitionSpecs The partition specs that defines the arrangement. + */ +case class BoostTuningColumnarCustomShuffleReaderExec( + child: SparkPlan, + partitionSpecs: Seq[ShufflePartitionSpec]) + extends UnaryExecNode { + // If this reader is to read shuffle files locally, then all partition specs should be + // `PartialMapperPartitionSpec`. + if (partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec])) { + assert(partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec])) + } + + override def nodeName: String = "BoostTuningOmniColumnarCustomShuffleReaderExec" + + override def supportsColumnar: Boolean = true + + override def output: Seq[Attribute] = child.output + override lazy val outputPartitioning: Partitioning = { + // If it is a local shuffle reader with one mapper per task, then the output partitioning is + // the same as the plan before shuffle. 
+ if (partitionSpecs.nonEmpty && + partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]) && + partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size == + partitionSpecs.length) { + child match { + case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) => + s.child.outputPartitioning + case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) => + s.child.outputPartitioning match { + case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning] + case other => other + } + case _ => + throw new IllegalStateException("operating on canonicalization plan") + } + } else { + UnknownPartitioning(partitionSpecs.length) + } + } + + override def stringArgs: Iterator[Any] = { + val desc = if (isLocalReader) { + "local" + } else if (hasCoalescedPartition && hasSkewedPartition) { + "coalesced and skewed" + } else if (hasCoalescedPartition) { + "coalesced" + } else if (hasSkewedPartition) { + "skewed" + } else { + "" + } + Iterator(desc) + } + + def hasCoalescedPartition: Boolean = + partitionSpecs.exists(_.isInstanceOf[CoalescedPartitionSpec]) + + def hasSkewedPartition: Boolean = + partitionSpecs.exists(_.isInstanceOf[PartialReducerPartitionSpec]) + + def isLocalReader: Boolean = + partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) + + private def shuffleStage = child match { + case stage: ShuffleQueryStageExec => Some(stage) + case _ => None + } + + @transient private lazy val partitionDataSizes: Option[Seq[Long]] = { + if (partitionSpecs.nonEmpty && !isLocalReader && shuffleStage.get.mapStats.isDefined) { + val bytesByPartitionId = shuffleStage.get.mapStats.get.bytesByPartitionId + Some(partitionSpecs.map { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + startReducerIndex.until(endReducerIndex).map(bytesByPartitionId).sum + case p: PartialReducerPartitionSpec => p.dataSize + case p => throw new IllegalStateException("unexpected " + p) + }) + } else { + None + } + } + + private def sendDriverMetrics(): Unit = { + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val driverAccumUpdates = ArrayBuffer.empty[(Long, Long)] + + val numPartitionsMetric = metrics("numPartitions") + numPartitionsMetric.set(partitionSpecs.length) + driverAccumUpdates += (numPartitionsMetric.id -> partitionSpecs.length.toLong) + + if (hasSkewedPartition) { + val skewedSpecs = partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p + } + + val skewedPartitions = metrics("numSkewedPartitions") + val skewedSplits = metrics("numSkewedSplits") + + val numSkewedPartitions = skewedSpecs.map(_.reducerIndex).distinct.length + val numSplits = skewedSpecs.length + + skewedPartitions.set(numSkewedPartitions) + driverAccumUpdates += (skewedPartitions.id -> numSkewedPartitions) + + skewedSplits.set(numSplits) + driverAccumUpdates += (skewedSplits.id -> numSplits) + } + + partitionDataSizes.foreach { dataSizes => + val partitionDataSizeMetrics = metrics("partitionDataSize") + driverAccumUpdates ++= dataSizes.map(partitionDataSizeMetrics.id -> _) + // Set sum value to "partitionDataSize" metric. 
+ partitionDataSizeMetrics.set(dataSizes.sum) + } + + SQLMetrics.postDriverMetricsUpdatedByValue(sparkContext, executionId, driverAccumUpdates.toSeq) + } + + override lazy val metrics: Map[String, SQLMetric] = { + if (shuffleStage.isDefined) { + Map( + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { + if (isLocalReader) { + // We split the mapper partition evenly when creating local shuffle reader, so no + // data size info is available. + Map.empty + } else { + Map("partitionDataSize" -> + SQLMetrics.createSizeMetric(sparkContext, "partition data size")) + } + } ++ { + if (hasSkewedPartition) { + Map("numSkewedPartitions" -> + SQLMetrics.createMetric(sparkContext, "number of skewed partitions"), + "numSkewedSplits" -> + SQLMetrics.createMetric(sparkContext, "number of skewed partition splits")) + } else { + Map.empty + } + } + } else { + // It's a canonicalized plan, no need to report metrics. + Map.empty + } + } + + private var cachedShuffleRDD: RDD[ColumnarBatch] = null + + private lazy val shuffleRDD: RDD[_] = { + sendDriverMetrics() + if (cachedShuffleRDD == null) { + cachedShuffleRDD = child match { + case stage: ShuffleQueryStageExec => + new ShuffledColumnarRDD( + stage.shuffle + .asInstanceOf[BoostTuningColumnarShuffleExchangeExec] + .boostTuningColumnarShuffleDependency, + stage.shuffle.asInstanceOf[BoostTuningColumnarShuffleExchangeExec].readMetrics, + partitionSpecs.toArray) + case _ => + throw new IllegalStateException("operating on canonicalized plan") + } + } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs"), + longMetric("bypassVecBatchs")) + } + } else { + cachedShuffleRDD + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + shuffleRDD.asInstanceOf[RDD[ColumnarBatch]] + } +} diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..c270a567142d0da6bec16c17d3169b8ba38c2756 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala @@ -0,0 +1,155 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.rule + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ock.BoostTuningQueryManager +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger.TLogWarning +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningUtil.{getQueryExecutionId, normalizedSparkPlan} +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration.enableColumnarShuffle +import org.apache.spark.sql.execution.adaptive.ock.common.StringPrefix.SHUFFLE_PREFIX +import org.apache.spark.sql.execution.adaptive.ock.exchange._ +import org.apache.spark.sql.execution.adaptive.ock.reader._ +import org.apache.spark.sql.execution.adaptive.{CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + +import scala.collection.mutable + +case class OmniOpBoostTuningColumnarRule(pre: Rule[SparkPlan], post: Rule[SparkPlan]) extends ColumnarRule { + override def preColumnarTransitions: Rule[SparkPlan] = pre + + override def postColumnarTransitions: Rule[SparkPlan] = post +} + +object OmniOpBoostTuningColumnarRule { + val rollBackExchangeIdents: mutable.Set[String] = mutable.Set.empty +} + +case class OmniOpBoostTuningPreColumnarRule() extends Rule[SparkPlan] { + + override val ruleName: String = "OmniOpBoostTuningPreColumnarRule" + + val delegate: BoostTuningPreNewQueryStageRule = BoostTuningPreNewQueryStageRule() + + override def apply(plan: SparkPlan): SparkPlan = { + val executionId = getQueryExecutionId(plan) + if (executionId < 0) { + TLogWarning(s"Skipped to apply BoostTuning new query stage rule for unneeded plan: $plan") + return plan + } + + val query = BoostTuningQueryManager.getOrCreateQueryManager(executionId) + + delegate.prepareQueryExecution(query, plan) + + delegate.reportQueryShuffleMetrics(query, plan) + + tryMarkRollBack(plan) + + replaceOmniQueryExchange(plan) + } + + private def tryMarkRollBack(plan: SparkPlan): Unit = { + plan.foreach { + case plan: BoostTuningShuffleExchangeLike => + if (!enableColumnarShuffle) { + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + } + try { + BoostTuningColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin, null).buildCheck() + } catch { + case e: UnsupportedOperationException => + logDebug(s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + case l: UnsatisfiedLinkError => + throw l + case f: NoClassDefFoundError => + throw f + case r: RuntimeException => + logDebug(s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + case t: Throwable => + logDebug(s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + } + case _ => + } + } + + def replaceOmniQueryExchange(plan: SparkPlan): SparkPlan = { + plan.transformUp { + case ex: ColumnarShuffleExchangeExec => + BoostTuningColumnarShuffleExchangeExec( + ex.outputPartitioning, ex.child, ex.shuffleOrigin, + PartitionContext(normalizedSparkPlan(ex, SHUFFLE_PREFIX))) + } + } +} + +case class OmniOpBoostTuningPostColumnarRule() extends Rule[SparkPlan] { + + 
override val ruleName: String = "OmniOpBoostTuningPostColumnarRule" + + override def apply(plan: SparkPlan): SparkPlan = { + + var newPlan = plan match { + case b: BoostTuningShuffleExchangeLike if !OmniOpBoostTuningColumnarRule.rollBackExchangeIdents.contains(b.getContext.ident) => + b.child match { + case ColumnarToRowExec(child) => + BoostTuningColumnarShuffleExchangeExec(b.outputPartitioning, child, b.shuffleOrigin, b.getContext) + case plan if !plan.supportsColumnar => + BoostTuningColumnarShuffleExchangeExec(b.outputPartitioning, RowToOmniColumnarExec(plan), b.shuffleOrigin, b.getContext) + case _ => b + } + case _ => plan + } + + newPlan = additionalReplaceWithColumnarPlan(newPlan) + + newPlan.transformUp { + case c: CustomShuffleReaderExec if ColumnarPluginConfig.getConf.enableColumnarShuffle => + c.child match { + case shuffle: BoostTuningColumnarShuffleExchangeExec => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningColumnarCustomShuffleReaderExec(c.child, c.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: BoostTuningColumnarShuffleExchangeExec) => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningColumnarCustomShuffleReaderExec(c.child, c.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) => + reused match { + case ReusedExchangeExec(_, shuffle: BoostTuningColumnarShuffleExchangeExec) => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningColumnarCustomShuffleReaderExec(c.child, c.partitionSpecs) + case _ => + c + } + case _ => + c + } + } + } + + def additionalReplaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { + case ColumnarToRowExec(child: BoostTuningShuffleExchangeLike) => + additionalReplaceWithColumnarPlan(child) + case r: SparkPlan + if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => + c.isInstanceOf[ColumnarToRowExec]) => + val children = r.children.map { + case c: ColumnarToRowExec => + val child = additionalReplaceWithColumnarPlan(c.child) + OmniColumnarToRowExec(child) + case other => + additionalReplaceWithColumnarPlan(other) + } + r.withNewChildren(children) + case p => + val children = p.children.map(additionalReplaceWithColumnarPlan) + p.withNewChildren(children) + } +} + diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala new file mode 100644 index 0000000000000000000000000000000000000000..9740829d23d98581bfce55d5c06e98b3b027e88f --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.rule.relation + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{ColumnarSortMergeJoinExec, SortMergeJoinExec} + +object ColumnarSMJRelationMarker extends RelationMarker { + + override def solve(plan: SparkPlan): SparkPlan = plan.transformUp { + case csmj @ ColumnarSortMergeJoinExec(_, _, _, _, left, right, _, _) => + SMJRelationMarker.solveDepAndWorkGroupOfSMJExec(left, right) + csmj + case smj @ SortMergeJoinExec(_, _, _, _, left, right, _) => + SMJRelationMarker.solveDepAndWorkGroupOfSMJExec(left, right) + smj + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/pom.xml b/omnioperator/omniop-spark-extension-ock/pom.xml index 2d3f670bbc6407bc23bfef889f4aded7e1db108a..17c74a0ececf7dbfc80d49d4ee14a4df2625d838 100644 --- a/omnioperator/omniop-spark-extension-ock/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/pom.xml @@ -4,11 +4,13 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 + com.huawei.ock + omniop-spark-extension-ock + pom + Huawei Open Computing Kit for Spark + 23.0.0 + - cpp/ - cpp/build/releases/ - FALSE - 0.6.1 3.1.2 2.12.10 2.12 @@ -18,15 +20,9 @@ spark-3.1 3.2.0 3.1.1 - 22.0.0 + 23.0.0 - com.huawei.ock - ock-omniop-shuffle-manager - jar - Huawei Open Computing Kit for Spark, shuffle manager - 22.0.0 - org.scala-lang @@ -66,12 +62,12 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.1.0 + 1.3.0 com.huawei.kunpeng boostkit-omniop-spark - 3.1.1-1.1.0 + 3.1.1-1.3.0 com.huawei.ock @@ -103,103 +99,8 @@ - - - ${project.artifactId}-${project.version}-for-${input.version} - - - ${cpp.build.dir} - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - - net.alchim31.maven - scala-maven-plugin - ${scala.plugin.version} - - all - - - - - compile - testCompile - - - - -dependencyfile - ${project.build.directory}/.scala_dependencies - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - org.apache.maven.plugins - maven-compiler-plugin - 3.1 - - 8 - 8 - true - - -Xlint:all - - - - - exec-maven-plugin - org.codehaus.mojo - 3.0.0 - - - Build CPP - generate-resources - - exec - - - bash - - ${cpp.dir}/build.sh - ${plugin.cpp.test} - - - - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - ${protobuf.maven.version} - - ${project.basedir}/../cpp/src/proto - - - - - compile - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - ${maven.plugin.version} - - - - - \ No newline at end of file + + ock-omniop-shuffle + ock-omniop-tuning + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep b/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 
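
To make the arithmetic in ColumnarExecutionModel (added earlier in this patch) easy to check, here is a small self-contained Scala recap of the same formula with illustrative constants. The real implementation reads executor memory, reserved memory, spark.memory.fraction, spark.memory.storageFraction, off-heap settings and executor cores from SparkConf at runtime; everything below is hard-coded for the example.

// Self-contained recap of the execution-memory estimate used by
// ColumnarExecutionModel; all inputs are illustrative constants here.
object ExecutionMemoryExample {
  def executionMemPerCore(executorMemory: Long,
                          reservedMem: Long,
                          memoryFraction: Double,
                          storageFraction: Double,
                          offHeapMem: Long,
                          executorCores: Long): Long = {
    val usableMem = executorMemory - reservedMem
    // Only the execution share of unified memory counts for shuffle.
    val shuffleMemFraction = memoryFraction * (1 - storageFraction)
    ((usableMem * shuffleMemFraction + offHeapMem) / executorCores).toLong
  }

  def main(args: Array[String]): Unit = {
    // 8 GiB executor, 300 MiB reserved, Spark's default fractions (0.6 / 0.5),
    // no off-heap memory, 4 cores => roughly 620 MB (~592 MiB) per core.
    val perCore = executionMemPerCore(
      executorMemory = 8L * 1024 * 1024 * 1024,
      reservedMem = 300L * 1024 * 1024,
      memoryFraction = 0.6,
      storageFraction = 0.5,
      offHeapMem = 0L,
      executorCores = 4L)
    println(s"execution memory per core: $perCore bytes")
  }
}
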
diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index dd0b79dbabc3b9c65ff8edf8b6c34853f51a0f63..491cfb7086037229608f2963cf6c278ca132b198 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -5,7 +5,7 @@ project(spark-thestral-plugin) cmake_minimum_required(VERSION 3.10) # configure cmake -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_COMPILER "g++") set(root_directory ${PROJECT_BINARY_DIR}) diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt index dbcffef431b42391f036d1ca523f05b6fc85b459..31c50564588ee7281d2e268f24972aedab245940 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -15,7 +15,11 @@ set (SOURCE_FILES jni/SparkJniWrapper.cpp jni/OrcColumnarBatchJniReader.cpp jni/jni_common.cpp - ) + jni/ParquetColumnarBatchJniReader.cpp + tablescan/ParquetReader.cpp + io/orcfile/OrcFileRewrite.cc + hdfs/hdfs_internal.cpp + io/orcfile/HdfsFileInputStreamV2.cpp) #Find required protobuf package find_package(Protobuf REQUIRED) @@ -30,12 +34,19 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) protobuf_generate_cpp(PROTO_SRCS_VB PROTO_HDRS_VB proto/vec_data.proto) add_library (${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB}) +find_package(Arrow REQUIRED) +find_package(ArrowDataset REQUIRED) +find_package(Parquet REQUIRED) + #JNI target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include) target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) 
target_include_directories(${PROJ_TARGET} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries (${PROJ_TARGET} PUBLIC + Arrow::arrow_shared + ArrowDataset::arrow_dataset_shared + Parquet::parquet_shared orc crypto sasl2 @@ -44,8 +55,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC snappy lz4 zstd - boostkit-omniop-runtime-1.1.0-aarch64 - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.3.0-aarch64 ) set_target_properties(${PROJ_TARGET} PROPERTIES diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp index 2c6b9fab89ec31c7df596cc4e9b14e3f869a12b2..f33d5c4c9df9695c2464b622587dea9e3546c39c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp @@ -76,21 +76,4 @@ spark::CompressionKind GetCompressionType(const std::string& name) { int IsFileExist(const std::string path) { return !access(path.c_str(), F_OK); -} - -void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb) -{ - int tmpVectorNum = vb.GetVectorCount(); - std::set vectorBatchAddresses; - vectorBatchAddresses.clear(); - for (int vecIndex = 0; vecIndex < tmpVectorNum; ++vecIndex) { - vectorBatchAddresses.insert(vb.GetVector(vecIndex)); - } - for (Vector * tmpAddress : vectorBatchAddresses) { - if (nullptr == tmpAddress) { - throw std::runtime_error("delete nullptr error for release vectorBatch"); - } - delete tmpAddress; - } - vectorBatchAddresses.clear(); } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.h b/omnioperator/omniop-spark-extension/cpp/src/common/common.h index fdc3b10e692e3944eeee9cf70f96ed47262a5e77..733dac920727489b205727d32300252bd32626c5 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.h @@ -45,6 +45,4 @@ spark::CompressionKind GetCompressionType(const std::string& name); int IsFileExist(const std::string path); -void ReleaseVectorBatch(omniruntime::vec::VectorBatch& vb); - #endif //CPP_COMMON_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.cpp b/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c10e5b8a01906039d26c3a9ae1fc1a9f3984044 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.cpp @@ -0,0 +1,74 @@ +// +// Created by l00451143 on 2023/11/27. 
+// + +#include "hdfs_internal.h" +#include + +using namespace orc; + +LibHdfsShim::LibHdfsShim() { + // std::cout << "Create to new hdfs filesystem"<< std::endl; +} + +LibHdfsShim::~LibHdfsShim() { + // std::cout << "Begin to release hdfs filesystem"<< std::endl; + if (fs_ != nullptr && file_ != nullptr){ + this->CloseFile(); + } + if (fs_ != nullptr){ + this->Disconnect(); + } + // std::cout << "End to release hdfs filesystem"<< std::endl; +} + +StatusCode LibHdfsShim::Connect(const char *url, tPort port) { + // std::string urlStr(url); + // std::cout << "url: " << urlStr << ", port: " << port << std::endl; + this->fs_= hdfsConnect(url, port); + if (!fs_) { + // std::cout << "Fail to connect filesystem"<< std::endl; + return StatusCode::FSConnectError; + } + return StatusCode::OK; +} + +StatusCode LibHdfsShim::OpenFile(const char *path, int bufferSize, short replication, + int32_t blocksize) { + // std::string pathStr(path); + // std::cout << "path: " << pathStr << ", bufferSize: " << bufferSize << ", replication: " << replication << ", blocksize: " << blocksize << std::endl; + this->file_ = hdfsOpenFile(this->fs_, path, O_RDONLY, bufferSize, replication, blocksize); + if (!file_) { + // std::cout << "Fail to open file"<< std::endl; + this->Disconnect(); + return StatusCode::OpenFileError; + } + return StatusCode::OK; +} + +int LibHdfsShim::GetFileSize(const char *path) { + // std::string pathStr(path); + // std::cout << "path: " << pathStr << std::endl; + hdfsFileInfo* fileInfo = hdfsGetPathInfo(this->fs_, path); + if (!fileInfo){ + std::cout << "Fail to get path info"<< std::endl; + return -1; + } + // std::string fileName(fileInfo->mName); + // std::cout << "Success get path info, size: " << fileInfo->mSize << ", fileName: " << fileName << std::endl; + int size = fileInfo->mSize; + hdfsFreeFileInfo(fileInfo, 1); + return size; +} + +int32_t LibHdfsShim::Read(void *buffer, int32_t length, int64_t offset) { + return hdfsPread(this->fs_, this->file_, offset, buffer, length); +} + +int LibHdfsShim::CloseFile() { + // std::cout << "Close hdfs filesystem"<< std::endl; + return hdfsCloseFile(this->fs_, this->file_); +} + +int LibHdfsShim::Disconnect() { + // std::cout << "Disconnect hdfs filesystem"<< std::endl; + return hdfsDisconnect(this->fs_); +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.h b/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..be153c1f307afdd7f429934bdec35cd1f8500306 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/hdfs/hdfs_internal.h @@ -0,0 +1,38 @@ +// +// Created by l00451143 on 2023/11/27.
+// + +#ifndef SPARK_THESTRAL_PLUGIN_HDFS_INTERNAL_H +#define SPARK_THESTRAL_PLUGIN_HDFS_INTERNAL_H + +#endif //SPARK_THESTRAL_PLUGIN_HDFS_INTERNAL_H + +#include "include/hdfs.h" +#include "status.h" + +namespace orc { + +class LibHdfsShim { +public: + LibHdfsShim(); + ~LibHdfsShim(); + + // return hdfsFS + StatusCode Connect(const char* url, tPort port); + // return hdfsFile + StatusCode OpenFile(const char* path, int bufferSize, short replication, int32_t blocksize); + // return tSize + int32_t Read( void* buffer, int32_t length, int64_t offset); + + int GetFileSize(const char* path); + +private: + hdfsFS fs_; + hdfsFile file_; + + int CloseFile(); + + int Disconnect(); +}; + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/hdfs/status.h b/omnioperator/omniop-spark-extension/cpp/src/hdfs/status.h new file mode 100644 index 0000000000000000000000000000000000000000..185f9870c21d42c6a73d2723a3617097ce86ace4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/hdfs/status.h @@ -0,0 +1,22 @@ +// +// Created by l00451143 on 2023/11/27. +// + +#ifndef SPARK_THESTRAL_PLUGIN_STATUS_H +#define SPARK_THESTRAL_PLUGIN_STATUS_H + +#endif //SPARK_THESTRAL_PLUGIN_STATUS_H +namespace orc { + + enum StatusCode : char { + OK = 0, + FSConnectError = 1, + OpenFileError = 2, + ReadFileError = 3, + InfoFileError = 4 + }; + class Status { + public: + static bool ok(StatusCode code) { return code == OK; } + }; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/include/hdfs.h b/omnioperator/omniop-spark-extension/cpp/src/include/hdfs.h new file mode 100644 index 0000000000000000000000000000000000000000..b8f47dbe14326e93f9cc55474b6a63600bce711f --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/include/hdfs.h @@ -0,0 +1,1086 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBHDFS_HDFS_H +#define LIBHDFS_HDFS_H + +#include /* for EINTERNAL, etc. */ +#include /* for O_RDONLY, O_WRONLY */ +#include /* for uint64_t, etc. */ +#include /* for time_t */ + +/* + * Support export of DLL symbols during libhdfs build, and import of DLL symbols + * during client application build. A client application may optionally define + * symbol LIBHDFS_DLL_IMPORT in its build. This is not strictly required, but + * the compiler can produce more efficient code with it. 
+ */ +#ifdef WIN32 + #ifdef LIBHDFS_DLL_EXPORT + #define LIBHDFS_EXTERNAL __declspec(dllexport) + #elif LIBHDFS_DLL_IMPORT + #define LIBHDFS_EXTERNAL __declspec(dllimport) + #else + #define LIBHDFS_EXTERNAL + #endif +#else + #ifdef LIBHDFS_DLL_EXPORT + #define LIBHDFS_EXTERNAL __attribute__((visibility("default"))) + #elif LIBHDFS_DLL_IMPORT + #define LIBHDFS_EXTERNAL __attribute__((visibility("default"))) + #else + #define LIBHDFS_EXTERNAL + #endif +#endif + +#ifndef O_RDONLY +#define O_RDONLY 1 +#endif + +#ifndef O_WRONLY +#define O_WRONLY 2 +#endif + +#ifndef EINTERNAL +#define EINTERNAL 255 +#endif + +#define ELASTIC_BYTE_BUFFER_POOL_CLASS \ + "org/apache/hadoop/io/ElasticByteBufferPool" + +/** All APIs set errno to meaningful values */ + +#ifdef __cplusplus +extern "C" { +#endif + /** + * Some utility decls used in libhdfs. + */ + struct hdfsBuilder; + typedef int32_t tSize; /// size of data for read/write io ops + typedef time_t tTime; /// time type in seconds + typedef int64_t tOffset;/// offset within the file + typedef uint16_t tPort; /// port + typedef enum tObjectKind { + kObjectKindFile = 'F', + kObjectKindDirectory = 'D', + } tObjectKind; + struct hdfsStreamBuilder; + + + /** + * The C reflection of org.apache.org.hadoop.FileSystem . + */ + struct hdfs_internal; + typedef struct hdfs_internal* hdfsFS; + + struct hdfsFile_internal; + typedef struct hdfsFile_internal* hdfsFile; + + struct hadoopRzOptions; + + struct hadoopRzBuffer; + + /** + * Determine if a file is open for read. + * + * @param file The HDFS file + * @return 1 if the file is open for read; 0 otherwise + */ + LIBHDFS_EXTERNAL + int hdfsFileIsOpenForRead(hdfsFile file); + + /** + * Determine if a file is open for write. + * + * @param file The HDFS file + * @return 1 if the file is open for write; 0 otherwise + */ + LIBHDFS_EXTERNAL + int hdfsFileIsOpenForWrite(hdfsFile file); + + struct hdfsReadStatistics { + uint64_t totalBytesRead; + uint64_t totalLocalBytesRead; + uint64_t totalShortCircuitBytesRead; + uint64_t totalZeroCopyBytesRead; + }; + + /** + * Get read statistics about a file. This is only applicable to files + * opened for reading. + * + * @param file The HDFS file + * @param stats (out parameter) on a successful return, the read + * statistics. Unchanged otherwise. You must free the + * returned statistics with hdfsFileFreeReadStatistics. + * @return 0 if the statistics were successfully returned, + * -1 otherwise. On a failure, please check errno against + * ENOTSUP. webhdfs, LocalFilesystem, and so forth may + * not support read statistics. + */ + LIBHDFS_EXTERNAL + int hdfsFileGetReadStatistics(hdfsFile file, + struct hdfsReadStatistics **stats); + + /** + * @param stats HDFS read statistics for a file. + * + * @return the number of remote bytes read. + */ + LIBHDFS_EXTERNAL + int64_t hdfsReadStatisticsGetRemoteBytesRead( + const struct hdfsReadStatistics *stats); + + /** + * Clear the read statistics for a file. + * + * @param file The file to clear the read statistics of. + * + * @return 0 on success; the error code otherwise. + * EINVAL: the file is not open for reading. + * ENOTSUP: the file does not support clearing the read + * statistics. + * Errno will also be set to this code on failure. + */ + LIBHDFS_EXTERNAL + int hdfsFileClearReadStatistics(hdfsFile file); + + /** + * Free some HDFS read statistics. + * + * @param stats The HDFS read statistics to free. 
+ */ + LIBHDFS_EXTERNAL + void hdfsFileFreeReadStatistics(struct hdfsReadStatistics *stats); + + struct hdfsHedgedReadMetrics { + uint64_t hedgedReadOps; + uint64_t hedgedReadOpsWin; + uint64_t hedgedReadOpsInCurThread; + }; + + /** + * Get cluster wide hedged read metrics. + * + * @param fs The configured filesystem handle + * @param metrics (out parameter) on a successful return, the hedged read + * metrics. Unchanged otherwise. You must free the returned + * statistics with hdfsFreeHedgedReadMetrics. + * @return 0 if the metrics were successfully returned, -1 otherwise. + * On a failure, please check errno against + * ENOTSUP. webhdfs, LocalFilesystem, and so forth may + * not support hedged read metrics. + */ + LIBHDFS_EXTERNAL + int hdfsGetHedgedReadMetrics(hdfsFS fs, struct hdfsHedgedReadMetrics **metrics); + + /** + * Free HDFS Hedged read metrics. + * + * @param metrics The HDFS Hedged read metrics to free + */ + LIBHDFS_EXTERNAL + void hdfsFreeHedgedReadMetrics(struct hdfsHedgedReadMetrics *metrics); + + /** + * hdfsConnectAsUser - Connect to a hdfstest file system as a specific user + * Connect to the hdfstest. + * @param nn The NameNode. See hdfsBuilderSetNameNode for details. + * @param port The port on which the server is listening. + * @param user the user name (this is hadoop domain user). Or NULL is equivelant to hhdfsConnect(host, port) + * @return Returns a handle to the filesystem or NULL on error. + * @deprecated Use hdfsBuilderConnect instead. + */ + LIBHDFS_EXTERNAL + hdfsFS hdfsConnectAsUser(const char* nn, tPort port, const char *user); + + /** + * Connect - Connect to a hdfstest file system. + * Connect to the hdfstest. + * @param nn The NameNode. See hdfsBuilderSetNameNode for details. + * @param port The port on which the server is listening. + * @return Returns a handle to the filesystem or NULL on error. + * @deprecated Use hdfsBuilderConnect instead. + */ + LIBHDFS_EXTERNAL + hdfsFS hdfsConnect(const char* nn, tPort port); + + /** + * Connect - Connect to an hdfstest file system. + * + * Forces a new instance to be created + * + * @param nn The NameNode. See hdfsBuilderSetNameNode for details. + * @param port The port on which the server is listening. + * @param user The user name to use when connecting + * @return Returns a handle to the filesystem or NULL on error. + * @deprecated Use hdfsBuilderConnect instead. + */ + LIBHDFS_EXTERNAL + hdfsFS hdfsConnectAsUserNewInstance(const char* nn, tPort port, const char *user ); + + /** + * Connect - Connect to an hdfstest file system. + * + * Forces a new instance to be created + * + * @param nn The NameNode. See hdfsBuilderSetNameNode for details. + * @param port The port on which the server is listening. + * @return Returns a handle to the filesystem or NULL on error. + * @deprecated Use hdfsBuilderConnect instead. + */ + LIBHDFS_EXTERNAL + hdfsFS hdfsConnectNewInstance(const char* nn, tPort port); + + /** + * Connect to HDFS using the parameters defined by the builder. + * + * The HDFS builder will be freed, whether or not the connection was + * successful. + * + * Every successful call to hdfsBuilderConnect should be matched with a call + * to Disconnect, when the hdfsFS is no longer needed. + * + * @param bld The HDFS builder + * @return Returns a handle to the filesystem, or NULL on error. + */ + LIBHDFS_EXTERNAL + hdfsFS hdfsBuilderConnect(struct hdfsBuilder *bld); + + /** + * Create an HDFS builder. + * + * @return The HDFS builder, or NULL on error. 
+ */ + LIBHDFS_EXTERNAL + struct hdfsBuilder *hdfsNewBuilder(void); + + /** + * Force the builder to always create a new instance of the FileSystem, + * rather than possibly finding one in the cache. + * + * @param bld The HDFS builder + */ + LIBHDFS_EXTERNAL + void hdfsBuilderSetForceNewInstance(struct hdfsBuilder *bld); + + /** + * Set the HDFS NameNode to connect to. + * + * @param bld The HDFS builder + * @param nn The NameNode to use. + * + * If the string given is 'default', the default NameNode + * configuration will be used (from the XML configuration files) + * + * If NULL is given, a LocalFileSystem will be created. + * + * If the string starts with a protocol type such as file:// or + * hdfstest://, this protocol type will be used. If not, the + * hdfstest:// protocol type will be used. + * + * You may specify a NameNode port in the usual way by + * passing a string of the format hdfstest://:. + * Alternately, you may set the port with + * hdfsBuilderSetNameNodePort. However, you must not pass the + * port in two different ways. + */ + LIBHDFS_EXTERNAL + void hdfsBuilderSetNameNode(struct hdfsBuilder *bld, const char *nn); + + /** + * Set the port of the HDFS NameNode to connect to. + * + * @param bld The HDFS builder + * @param port The port. + */ + LIBHDFS_EXTERNAL + void hdfsBuilderSetNameNodePort(struct hdfsBuilder *bld, tPort port); + + /** + * Set the username to use when connecting to the HDFS cluster. + * + * @param bld The HDFS builder + * @param userName The user name. The string will be shallow-copied. + */ + LIBHDFS_EXTERNAL + void hdfsBuilderSetUserName(struct hdfsBuilder *bld, const char *userName); + + /** + * Set the path to the Kerberos ticket cache to use when connecting to + * the HDFS cluster. + * + * @param bld The HDFS builder + * @param kerbTicketCachePath The Kerberos ticket cache path. The string + * will be shallow-copied. + */ + LIBHDFS_EXTERNAL + void hdfsBuilderSetKerbTicketCachePath(struct hdfsBuilder *bld, + const char *kerbTicketCachePath); + + /** + * Free an HDFS builder. + * + * It is normally not necessary to call this function since + * hdfsBuilderConnect frees the builder. + * + * @param bld The HDFS builder + */ + LIBHDFS_EXTERNAL + void hdfsFreeBuilder(struct hdfsBuilder *bld); + + /** + * Set a configuration string for an HdfsBuilder. + * + * @param key The key to set. + * @param val The value, or NULL to set no value. + * This will be shallow-copied. You are responsible for + * ensuring that it remains valid until the builder is + * freed. + * + * @return 0 on success; nonzero error code otherwise. + */ + LIBHDFS_EXTERNAL + int hdfsBuilderConfSetStr(struct hdfsBuilder *bld, const char *key, + const char *val); + + /** + * Get a configuration string. + * + * @param key The key to find + * @param val (out param) The value. This will be set to NULL if the + * key isn't found. You must free this string with + * hdfsConfStrFree. + * + * @return 0 on success; nonzero error code otherwise. + * Failure to find the key is not an error. + */ + LIBHDFS_EXTERNAL + int hdfsConfGetStr(const char *key, char **val); + + /** + * Get a configuration integer. + * + * @param key The key to find + * @param val (out param) The value. This will NOT be changed if the + * key isn't found. + * + * @return 0 on success; nonzero error code otherwise. + * Failure to find the key is not an error. + */ + LIBHDFS_EXTERNAL + int hdfsConfGetInt(const char *key, int32_t *val); + + /** + * Free a configuration string found with hdfsConfGetStr. 
+ * + * @param val A configuration string obtained from hdfsConfGetStr + */ + LIBHDFS_EXTERNAL + void hdfsConfStrFree(char *val); + + /** + * Disconnect - Disconnect from the hdfstest file system. + * Disconnect from hdfstest. + * @param fs The configured filesystem handle. + * @return Returns 0 on success, -1 on error. + * Even if there is an error, the resources associated with the + * hdfsFS will be freed. + */ + LIBHDFS_EXTERNAL + int hdfsDisconnect(hdfsFS fs); + + /** + * OpenFile - Open a hdfstest file in given mode. + * @deprecated Use the hdfsStreamBuilder functions instead. + * This function does not support setting block sizes bigger than 2 GB. + * + * @param fs The configured filesystem handle. + * @param path The full path to the file. + * @param flags - an | of bits/fcntl.h file flags - supported flags are O_RDONLY, O_WRONLY (meaning create or overwrite i.e., implies O_TRUNCAT), + * O_WRONLY|O_APPEND. Other flags are generally ignored other than (O_RDWR || (O_EXCL & O_CREAT)) which return NULL and set errno equal ENOTSUP. + * @param bufferSize Size of buffer for read/write - pass 0 if you want + * to use the default configured values. + * @param replication Block replication - pass 0 if you want to use + * the default configured values. + * @param blocksize Size of block - pass 0 if you want to use the + * default configured values. Note that if you want a block size bigger + * than 2 GB, you must use the hdfsStreamBuilder API rather than this + * deprecated function. + * @return Returns the handle to the open file or NULL on error. + */ + LIBHDFS_EXTERNAL + hdfsFile hdfsOpenFile(hdfsFS fs, const char* path, int flags, + int bufferSize, short replication, tSize blocksize); + + /** + * hdfsStreamBuilderAlloc - Allocate an HDFS stream builder. + * + * @param fs The configured filesystem handle. + * @param path The full path to the file. Will be deep-copied. + * @param flags The open flags, as in OpenFile. + * @return Returns the hdfsStreamBuilder, or NULL on error. + */ + LIBHDFS_EXTERNAL + struct hdfsStreamBuilder *hdfsStreamBuilderAlloc(hdfsFS fs, + const char *path, int flags); + + /** + * hdfsStreamBuilderFree - Free an HDFS file builder. + * + * It is normally not necessary to call this function since + * hdfsStreamBuilderBuild frees the builder. + * + * @param bld The hdfsStreamBuilder to free. + */ + LIBHDFS_EXTERNAL + void hdfsStreamBuilderFree(struct hdfsStreamBuilder *bld); + + /** + * hdfsStreamBuilderSetBufferSize - Set the stream buffer size. + * + * @param bld The hdfstest stream builder. + * @param bufferSize The buffer size to set. + * + * @return 0 on success, or -1 on error. Errno will be set on error. + */ + LIBHDFS_EXTERNAL + int hdfsStreamBuilderSetBufferSize(struct hdfsStreamBuilder *bld, + int32_t bufferSize); + + /** + * hdfsStreamBuilderSetReplication - Set the replication for the stream. + * This is only relevant for output streams, which will create new blocks. + * + * @param bld The hdfstest stream builder. + * @param replication The replication to set. + * + * @return 0 on success, or -1 on error. Errno will be set on error. + * If you call this on an input stream builder, you will get + * EINVAL, because this configuration is not relevant to input + * streams. + */ + LIBHDFS_EXTERNAL + int hdfsStreamBuilderSetReplication(struct hdfsStreamBuilder *bld, + int16_t replication); + + /** + * hdfsStreamBuilderSetDefaultBlockSize - Set the default block size for + * the stream. This is only relevant for output streams, which will create + * new blocks. 
+ * + * @param bld The hdfstest stream builder. + * @param defaultBlockSize The default block size to set. + * + * @return 0 on success, or -1 on error. Errno will be set on error. + * If you call this on an input stream builder, you will get + * EINVAL, because this configuration is not relevant to input + * streams. + */ + LIBHDFS_EXTERNAL + int hdfsStreamBuilderSetDefaultBlockSize(struct hdfsStreamBuilder *bld, + int64_t defaultBlockSize); + + /** + * hdfsStreamBuilderBuild - Build the stream by calling open or create. + * + * @param bld The hdfstest stream builder. This pointer will be freed, whether + * or not the open succeeds. + * + * @return the stream pointer on success, or NULL on error. Errno will be + * set on error. + */ + LIBHDFS_EXTERNAL + hdfsFile hdfsStreamBuilderBuild(struct hdfsStreamBuilder *bld); + + /** + * hdfsTruncateFile - Truncate a hdfstest file to given lenght. + * @param fs The configured filesystem handle. + * @param path The full path to the file. + * @param newlength The size the file is to be truncated to + * @return 1 if the file has been truncated to the desired newlength + * and is immediately available to be reused for write operations + * such as append. + * 0 if a background process of adjusting the length of the last + * block has been started, and clients should wait for it to + * complete before proceeding with further file updates. + * -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsTruncateFile(hdfsFS fs, const char* path, tOffset newlength); + + /** + * hdfsUnbufferFile - Reduce the buffering done on a file. + * + * @param file The file to unbuffer. + * @return 0 on success + * ENOTSUP if the file does not support unbuffering + * Errno will also be set to this value. + */ + LIBHDFS_EXTERNAL + int hdfsUnbufferFile(hdfsFile file); + + /** + * CloseFile - Close an open file. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @return Returns 0 on success, -1 on error. + * On error, errno will be set appropriately. + * If the hdfstest file was valid, the memory associated with it will + * be freed at the end of this call, even if there was an I/O + * error. + */ + LIBHDFS_EXTERNAL + int hdfsCloseFile(hdfsFS fs, hdfsFile file); + + + /** + * hdfsExists - Checks if a given path exsits on the filesystem + * @param fs The configured filesystem handle. + * @param path The path to look for + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsExists(hdfsFS fs, const char *path); + + + /** + * hdfsSeek - Seek to given offset in file. + * This works only for files opened in read-only mode. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @param desiredPos Offset into the file to seek into. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsSeek(hdfsFS fs, hdfsFile file, tOffset desiredPos); + + + /** + * hdfsTell - Get the current offset in the file, in bytes. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @return Current offset, -1 on error. + */ + LIBHDFS_EXTERNAL + tOffset hdfsTell(hdfsFS fs, hdfsFile file); + + + /** + * Read - Read data from an open file. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @param buffer The buffer to copy read bytes into. + * @param length The length of the buffer. + * @return On success, a positive number indicating how many bytes + * were read. + * On end-of-file, 0. + * On error, -1. Errno will be set to the error code. 
+ * Just like the POSIX read function, Read will return -1 + * and set errno to EINTR if data is temporarily unavailable, + * but we are not yet at the end of the file. + */ + LIBHDFS_EXTERNAL + tSize hdfsRead(hdfsFS fs, hdfsFile file, void* buffer, tSize length); + + /** + * hdfsPread - Positional read of data from an open file. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @param position Position from which to read + * @param buffer The buffer to copy read bytes into. + * @param length The length of the buffer. + * @return See Read + */ + LIBHDFS_EXTERNAL + tSize hdfsPread(hdfsFS fs, hdfsFile file, tOffset position, + void* buffer, tSize length); + + + /** + * hdfsWrite - Write data into an open file. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @param buffer The data. + * @param length The no. of bytes to write. + * @return Returns the number of bytes written, -1 on error. + */ + LIBHDFS_EXTERNAL + tSize hdfsWrite(hdfsFS fs, hdfsFile file, const void* buffer, + tSize length); + + + /** + * hdfsWrite - Flush the data. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsFlush(hdfsFS fs, hdfsFile file); + + + /** + * hdfsHFlush - Flush out the data in client's user buffer. After the + * return of this call, new readers will see the data. + * @param fs configured filesystem handle + * @param file file handle + * @return 0 on success, -1 on error and sets errno + */ + LIBHDFS_EXTERNAL + int hdfsHFlush(hdfsFS fs, hdfsFile file); + + + /** + * hdfsHSync - Similar to posix fsync, Flush out the data in client's + * user buffer. all the way to the disk device (but the disk may have + * it in its cache). + * @param fs configured filesystem handle + * @param file file handle + * @return 0 on success, -1 on error and sets errno + */ + LIBHDFS_EXTERNAL + int hdfsHSync(hdfsFS fs, hdfsFile file); + + + /** + * hdfsAvailable - Number of bytes that can be read from this + * input stream without blocking. + * @param fs The configured filesystem handle. + * @param file The file handle. + * @return Returns available bytes; -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsAvailable(hdfsFS fs, hdfsFile file); + + + /** + * hdfsCopy - Copy file from one filesystem to another. + * @param srcFS The handle to source filesystem. + * @param src The path of source file. + * @param dstFS The handle to destination filesystem. + * @param dst The path of destination file. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsCopy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst); + + + /** + * hdfsMove - Move file from one filesystem to another. + * @param srcFS The handle to source filesystem. + * @param src The path of source file. + * @param dstFS The handle to destination filesystem. + * @param dst The path of destination file. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsMove(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst); + + + /** + * hdfsDelete - Delete file. + * @param fs The configured filesystem handle. + * @param path The path of the file. + * @param recursive if path is a directory and set to + * non-zero, the directory is deleted else throws an exception. In + * case of a file the recursive argument is irrelevant. + * @return Returns 0 on success, -1 on error. 
+ */ + LIBHDFS_EXTERNAL + int hdfsDelete(hdfsFS fs, const char* path, int recursive); + + /** + * hdfsRename - Rename file. + * @param fs The configured filesystem handle. + * @param oldPath The path of the source file. + * @param newPath The path of the destination file. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsRename(hdfsFS fs, const char* oldPath, const char* newPath); + + + /** + * hdfsGetWorkingDirectory - Get the current working directory for + * the given filesystem. + * @param fs The configured filesystem handle. + * @param buffer The user-buffer to copy path of cwd into. + * @param bufferSize The length of user-buffer. + * @return Returns buffer, NULL on error. + */ + LIBHDFS_EXTERNAL + char* hdfsGetWorkingDirectory(hdfsFS fs, char *buffer, size_t bufferSize); + + + /** + * hdfsSetWorkingDirectory - Set the working directory. All relative + * paths will be resolved relative to it. + * @param fs The configured filesystem handle. + * @param path The path of the new 'cwd'. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsSetWorkingDirectory(hdfsFS fs, const char* path); + + + /** + * hdfsCreateDirectory - Make the given file and all non-existent + * parents into directories. + * @param fs The configured filesystem handle. + * @param path The path of the directory. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsCreateDirectory(hdfsFS fs, const char* path); + + + /** + * hdfsSetReplication - Set the replication of the specified + * file to the supplied value + * @param fs The configured filesystem handle. + * @param path The path of the file. + * @return Returns 0 on success, -1 on error. + */ + LIBHDFS_EXTERNAL + int hdfsSetReplication(hdfsFS fs, const char* path, int16_t replication); + + + /** + * hdfsFileInfo - Information about a file/directory. + */ + typedef struct { + tObjectKind mKind; /* file or directory */ + char *mName; /* the name of the file */ + tTime mLastMod; /* the last modification time for the file in seconds */ + tOffset mSize; /* the size of the file in bytes */ + short mReplication; /* the count of replicas */ + tOffset mBlockSize; /* the block size for the file */ + char *mOwner; /* the owner of the file */ + char *mGroup; /* the group associated with the file */ + short mPermissions; /* the permissions associated with the file */ + tTime mLastAccess; /* the last access time for the file in seconds */ + } hdfsFileInfo; + + + /** + * hdfsListDirectory - Get list of files/directories for a given + * directory-path. hdfsFreeFileInfo should be called to deallocate memory. + * @param fs The configured filesystem handle. + * @param path The path of the directory. + * @param numEntries Set to the number of files/directories in path. + * @return Returns a dynamically-allocated array of hdfsFileInfo + * objects; NULL on error or empty directory. + * errno is set to non-zero on error or zero on success. + */ + LIBHDFS_EXTERNAL + hdfsFileInfo *hdfsListDirectory(hdfsFS fs, const char* path, + int *numEntries); + + + /** + * hdfsGetPathInfo - Get information about a path as a (dynamically + * allocated) single hdfsFileInfo struct. hdfsFreeFileInfo should be + * called when the pointer is no longer needed. + * @param fs The configured filesystem handle. + * @param path The path of the file. + * @return Returns a dynamically-allocated hdfsFileInfo object; + * NULL on error. 
+ */ + LIBHDFS_EXTERNAL + hdfsFileInfo *hdfsGetPathInfo(hdfsFS fs, const char* path); + + + /** + * hdfsFreeFileInfo - Free up the hdfsFileInfo array (including fields) + * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo + * objects. + * @param numEntries The size of the array. + */ + LIBHDFS_EXTERNAL + void hdfsFreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries); + + /** + * hdfsFileIsEncrypted: determine if a file is encrypted based on its + * hdfsFileInfo. + * @return -1 if there was an error (errno will be set), 0 if the file is + * not encrypted, 1 if the file is encrypted. + */ + LIBHDFS_EXTERNAL + int hdfsFileIsEncrypted(hdfsFileInfo *hdfsFileInfo); + + + /** + * hdfsGetHosts - Get hostnames where a particular block (determined by + * pos & blocksize) of a file is stored. The last element in the array + * is NULL. Due to replication, a single block could be present on + * multiple hosts. + * @param fs The configured filesystem handle. + * @param path The path of the file. + * @param start The start of the block. + * @param length The length of the block. + * @return Returns a dynamically-allocated 2-d array of blocks-hosts; + * NULL on error. + */ + LIBHDFS_EXTERNAL + char*** hdfsGetHosts(hdfsFS fs, const char* path, + tOffset start, tOffset length); + + + /** + * hdfsFreeHosts - Free up the structure returned by hdfsGetHosts + * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo + * objects. + * @param numEntries The size of the array. + */ + LIBHDFS_EXTERNAL + void hdfsFreeHosts(char ***blockHosts); + + + /** + * hdfsGetDefaultBlockSize - Get the default blocksize. + * + * @param fs The configured filesystem handle. + * @deprecated Use hdfsGetDefaultBlockSizeAtPath instead. + * + * @return Returns the default blocksize, or -1 on error. + */ + LIBHDFS_EXTERNAL + tOffset hdfsGetDefaultBlockSize(hdfsFS fs); + + + /** + * hdfsGetDefaultBlockSizeAtPath - Get the default blocksize at the + * filesystem indicated by a given path. + * + * @param fs The configured filesystem handle. + * @param path The given path will be used to locate the actual + * filesystem. The full path does not have to exist. + * + * @return Returns the default blocksize, or -1 on error. + */ + LIBHDFS_EXTERNAL + tOffset hdfsGetDefaultBlockSizeAtPath(hdfsFS fs, const char *path); + + + /** + * hdfsGetCapacity - Return the raw capacity of the filesystem. + * @param fs The configured filesystem handle. + * @return Returns the raw-capacity; -1 on error. + */ + LIBHDFS_EXTERNAL + tOffset hdfsGetCapacity(hdfsFS fs); + + + /** + * hdfsGetUsed - Return the total raw size of all files in the filesystem. + * @param fs The configured filesystem handle. + * @return Returns the total-size; -1 on error. + */ + LIBHDFS_EXTERNAL + tOffset hdfsGetUsed(hdfsFS fs); + + /** + * Change the user and/or group of a file or directory. + * + * @param fs The configured filesystem handle. + * @param path the path to the file or directory + * @param owner User string. Set to NULL for 'no change' + * @param group Group string. Set to NULL for 'no change' + * @return 0 on success else -1 + */ + LIBHDFS_EXTERNAL + int hdfsChown(hdfsFS fs, const char* path, const char *owner, + const char *group); + + /** + * hdfsChmod + * @param fs The configured filesystem handle. 
+ * @param path the path to the file or directory + * @param mode the bitmask to set it to + * @return 0 on success else -1 + */ + LIBHDFS_EXTERNAL + int hdfsChmod(hdfsFS fs, const char* path, short mode); + + /** + * hdfsUtime + * @param fs The configured filesystem handle. + * @param path the path to the file or directory + * @param mtime new modification time or -1 for no change + * @param atime new access time or -1 for no change + * @return 0 on success else -1 + */ + LIBHDFS_EXTERNAL + int hdfsUtime(hdfsFS fs, const char* path, tTime mtime, tTime atime); + + /** + * Allocate a zero-copy options structure. + * + * You must free all options structures allocated with this function using + * hadoopRzOptionsFree. + * + * @return A zero-copy options structure, or NULL if one could + * not be allocated. If NULL is returned, errno will + * contain the error number. + */ + LIBHDFS_EXTERNAL + struct hadoopRzOptions *hadoopRzOptionsAlloc(void); + + /** + * Determine whether we should skip checksums in read0. + * + * @param opts The options structure. + * @param skip Nonzero to skip checksums sometimes; zero to always + * check them. + * + * @return 0 on success; -1 plus errno on failure. + */ + LIBHDFS_EXTERNAL + int hadoopRzOptionsSetSkipChecksum( + struct hadoopRzOptions *opts, int skip); + + /** + * Set the ByteBufferPool to use with read0. + * + * @param opts The options structure. + * @param className If this is NULL, we will not use any + * ByteBufferPool. If this is non-NULL, it will be + * treated as the name of the pool class to use. + * For example, you can use + * ELASTIC_BYTE_BUFFER_POOL_CLASS. + * + * @return 0 if the ByteBufferPool class was found and + * instantiated; + * -1 plus errno otherwise. + */ + LIBHDFS_EXTERNAL + int hadoopRzOptionsSetByteBufferPool( + struct hadoopRzOptions *opts, const char *className); + + /** + * Free a hadoopRzOptionsFree structure. + * + * @param opts The options structure to free. + * Any associated ByteBufferPool will also be freed. + */ + LIBHDFS_EXTERNAL + void hadoopRzOptionsFree(struct hadoopRzOptions *opts); + + /** + * Perform a byte buffer read. + * If possible, this will be a zero-copy (mmap) read. + * + * @param file The file to read from. + * @param opts An options structure created by hadoopRzOptionsAlloc. + * @param maxLength The maximum length to read. We may read fewer bytes + * than this length. + * + * @return On success, we will return a new hadoopRzBuffer. + * This buffer will continue to be valid and readable + * until it is released by readZeroBufferFree. Failure to + * release a buffer will lead to a memory leak. + * You can access the data within the hadoopRzBuffer with + * hadoopRzBufferGet. If you have reached EOF, the data + * within the hadoopRzBuffer will be NULL. You must still + * free hadoopRzBuffer instances containing NULL. + * + * On failure, we will return NULL plus an errno code. + * errno = EOPNOTSUPP indicates that we could not do a + * zero-copy read, and there was no ByteBufferPool + * supplied. + */ + LIBHDFS_EXTERNAL + struct hadoopRzBuffer* hadoopReadZero(hdfsFile file, + struct hadoopRzOptions *opts, int32_t maxLength); + + /** + * Determine the length of the buffer returned from readZero. + * + * @param buffer a buffer returned from readZero. + * @return the length of the buffer. + */ + LIBHDFS_EXTERNAL + int32_t hadoopRzBufferLength(const struct hadoopRzBuffer *buffer); + + /** + * Get a pointer to the raw buffer returned from readZero. 
+ * + * To find out how many bytes this buffer contains, call + * hadoopRzBufferLength. + * + * @param buffer a buffer returned from readZero. + * @return a pointer to the start of the buffer. This will be + * NULL when end-of-file has been reached. + */ + LIBHDFS_EXTERNAL + const void *hadoopRzBufferGet(const struct hadoopRzBuffer *buffer); + + /** + * Release a buffer obtained through readZero. + * + * @param file The hdfstest stream that created this buffer. This must be + * the same stream you called hadoopReadZero on. + * @param buffer The buffer to release. + */ + LIBHDFS_EXTERNAL + void hadoopRzBufferFree(hdfsFile file, struct hadoopRzBuffer *buffer); + + /** + * Get the last exception root cause that happened in the context of the + * current thread, i.e. the thread that called into libHDFS. + * + * The pointer returned by this function is guaranteed to be valid until + * the next call into libHDFS by the current thread. + * Users of this function should not free the pointer. + * + * A NULL will be returned if no exception information could be retrieved + * for the previous call. + * + * @return The root cause as a C-string. + */ + LIBHDFS_EXTERNAL + char* hdfsGetLastExceptionRootCause(); + + /** + * Get the last exception stack trace that happened in the context of the + * current thread, i.e. the thread that called into libHDFS. + * + * The pointer returned by this function is guaranteed to be valid until + * the next call into libHDFS by the current thread. + * Users of this function should not free the pointer. + * + * A NULL will be returned if no exception information could be retrieved + * for the previous call. + * + * @return The stack trace as a C-string. + */ + LIBHDFS_EXTERNAL + char* hdfsGetLastExceptionStackTrace(); + +#ifdef __cplusplus +} +#endif + +#undef LIBHDFS_EXTERNAL +#endif /*LIBHDFS_HDFS_H*/ + +/** + * vim: ts=4: sw=4: et + */ diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc index 51ff4b98f3eb4df234d927276e75bce1cb7bc158..3c6e3b3bc31746c7c28a2a63f5bd1b5b1b2a3e44 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc +++ b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc @@ -24,6 +24,7 @@ #include #include #include +#include #ifdef _MSC_VER #include diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/HdfsFileInputStreamV2.cpp b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/HdfsFileInputStreamV2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80e626cce2a7240e3e4f8fdd5ebb3605e9646e0c --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/HdfsFileInputStreamV2.cpp @@ -0,0 +1,103 @@ +// +// Created by l00451143 on 2023/11/27. 
+// + +#include + +#include +#include "OrcFileRewrite.hh" +#include "hdfs/hdfs_internal.h" + + +#include "OrcFileRewrite.hh" + +namespace orc { + + class HdfsFileInputStreamV2 : public InputStream { + private: + std::string filepath_; + uint64_t total_length_; + const uint64_t READ_SIZE = 1024 * 1024; //1 MB + + std::unique_ptr file_system_; + + public: + HdfsFileInputStreamV2(std::string path) { + // std::cout << "Begin to create hdfs input steam"<< std::endl; + this->file_system_ = std::make_unique(); + + hdfs::URI uri; + try { + uri = hdfs::URI::parse_from_string(path); + } catch (const hdfs::uri_parse_error&) { + throw ParseError("Malformed URI: " + path); + } + // std::cout << "Success to parse uri, host: " << uri.get_host().c_str() + // << ", port: " << uri.get_port() + // << ", file path: " << uri.get_path() + // << std::endl; + + this->filepath_ = uri.get_path(); + + StatusCode fs_status = file_system_->Connect(uri.get_host().c_str(),static_cast(uri.get_port())); + if (fs_status != OK){ + throw ParseError("URI: " + path + ", fail to connect filesystem."); + } + // std::cout << "Success to connect hdfs file system"<< std::endl; + + StatusCode file_status = file_system_->OpenFile(filepath_.c_str(), 0, 0, 0); + if (file_status != OK){ + throw ParseError("file path: " + filepath_ + ", fail to connect filesystem."); + } + // std::cout << "Success to connect open hdfs file"<< std::endl; + + this->total_length_ = file_system_->GetFileSize(filepath_.c_str()); + // std::cout << "end to create hdfs input steam, total_length_: " << total_length_ << std::endl; + } + + ~HdfsFileInputStreamV2() override { + } + + uint64_t getLength() const override { + return this->total_length_; + } + + uint64_t getNaturalReadSize() const override { + return this->READ_SIZE; + } + + const std::string& getName() const override { + return filepath_; + } + + void read(void* buf, + uint64_t length, + uint64_t offset) override { + if (!buf) { + throw ParseError("Buffer is null"); + } + + // std::cout << "hdfs file input stream, begin read, length: " << length << ", offset: " << offset << std::endl; + + char* buf_ptr = reinterpret_cast(buf); + int32_t total_bytes_read = 0; + int32_t last_bytes_read = 0; + + do{ + last_bytes_read = file_system_->Read(buf_ptr, length - total_bytes_read, offset + total_bytes_read); + if (last_bytes_read < 0) { + // std::cout << "Fail to get read file, read bytes: " << last_bytes_read << std::endl; + throw ParseError("Error reading bytes the file."); + } + total_bytes_read += last_bytes_read; + buf_ptr += last_bytes_read; + // std::cout << "read hdfs, total_bytes_read: " << total_bytes_read << ", last_bytes_read: " << last_bytes_read << ", buf_ptr: " << buf_ptr << std::endl; + } while (total_bytes_read < length); + // std::cout << "hdfs file input stream, end read, total_bytes_read: " << total_bytes_read << ", last_bytes_read: " << last_bytes_read << std::endl; + } + }; + + std::unique_ptr readHdfsFileRewrite(const std::string& path, std::vector& tokens) { + return std::unique_ptr(new HdfsFileInputStreamV2(path)); + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ec77da2ce30c96cbab5ab4f6dfd768f4648a502 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OrcFileRewrite.hh" +#include "orc/Exceptions.hh" +#include "io/Adaptor.hh" + +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#define S_IRUSR _S_IREAD +#define S_IWUSR _S_IWRITE +#define stat _stat64 +#define fstat _fstat64 +#else +#include +#define O_BINARY 0 +#endif + +namespace orc { + std::unique_ptr readFileRewrite(const std::string& path, std::vector& tokens) { + if (strncmp(path.c_str(), "hdfs://", 7) == 0) { + return orc::readHdfsFileRewrite(std::string(path), tokens); + } else if (strncmp(path.c_str(), "file:", 5) == 0) { + return orc::readLocalFile(std::string(path.substr(5))); + } else { + return orc::readLocalFile(std::string(path)); + } + } +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh new file mode 100644 index 0000000000000000000000000000000000000000..e7bcee95cecd9dd8b0ac7a120be74a507e47d8a5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_FILE_REWRITE_HH +#define ORC_FILE_REWRITE_HH + +#include + +#include "hdfspp/options.h" +#include "orc/OrcFile.hh" + +/** /file orc/OrcFile.hh + @brief The top level interface to ORC. +*/ + +namespace orc { + + /** + * Create a stream to a local file or HDFS file if path begins with "hdfs://" + * @param path the name of the file in the local file system or HDFS + */ + ORC_UNIQUE_PTR readFileRewrite(const std::string& path, std::vector& tokens); + + /** + * Create a stream to an HDFS file. 
+ * @param path the uri of the file in HDFS + */ + ORC_UNIQUE_PTR readHdfsFileRewrite(const std::string& path, std::vector& tokens); +} + +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp index 7506424fbeaa98f0a84f546662e6f4b361eeddaf..df67ac4297f0906f8b2a3346ba1798928ef42e39 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -18,89 +18,20 @@ */ #include "OrcColumnarBatchJniReader.h" +#include #include "jni_common.h" using namespace omniruntime::vec; +using namespace omniruntime::type; using namespace std; using namespace orc; -jclass runtimeExceptionClass; -jclass jsonClass; -jclass arrayListClass; -jmethodID jsonMethodInt; -jmethodID jsonMethodLong; -jmethodID jsonMethodHas; -jmethodID jsonMethodString; -jmethodID jsonMethodJsonObj; -jmethodID arrayListGet; -jmethodID arrayListSize; -jmethodID jsonMethodObj; - -int initJniId(JNIEnv *env) -{ - /* - * init table scan log - */ - jsonClass = env->FindClass("org/json/JSONObject"); - arrayListClass = env->FindClass("java/util/ArrayList"); - - arrayListGet = env->GetMethodID(arrayListClass, "get", "(I)Ljava/lang/Object;"); - arrayListSize = env->GetMethodID(arrayListClass, "size", "()I"); - - // get int method - jsonMethodInt = env->GetMethodID(jsonClass, "getInt", "(Ljava/lang/String;)I"); - if (jsonMethodInt == NULL) - return -1; - - // get long method - jsonMethodLong = env->GetMethodID(jsonClass, "getLong", "(Ljava/lang/String;)J"); - if (jsonMethodLong == NULL) - return -1; - - // get has method - jsonMethodHas = env->GetMethodID(jsonClass, "has", "(Ljava/lang/String;)Z"); - if (jsonMethodHas == NULL) - return -1; - - // get string method - jsonMethodString = env->GetMethodID(jsonClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); - if (jsonMethodString == NULL) - return -1; - - // get json object method - jsonMethodJsonObj = env->GetMethodID(jsonClass, "getJSONObject", "(Ljava/lang/String;)Lorg/json/JSONObject;"); - if (jsonMethodJsonObj == NULL) - return -1; - - // get json object method - jsonMethodObj = env->GetMethodID(jsonClass, "get", "(Ljava/lang/String;)Ljava/lang/Object;"); - if (jsonMethodJsonObj == NULL) - return -1; - - jclass local_class = env->FindClass("Ljava/lang/RuntimeException;"); - runtimeExceptionClass = (jclass)env->NewGlobalRef(local_class); - env->DeleteLocalRef(local_class); - if (runtimeExceptionClass == NULL) - return -1; - - return 0; -} - -void JNI_OnUnload(JavaVM *vm, const void *reserved) -{ - JNIEnv *env = nullptr; - vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_8); - env->DeleteGlobalRef(runtimeExceptionClass); -} +static constexpr int32_t MAX_DECIMAL64_DIGITS = 18; JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, jobject jObj, jstring path, jobject jsonObj) { JNI_FUNC_START - /* - * init logger and jni env method id - */ - initJniId(env); /* * get tailLocation from json obj @@ -121,26 +52,26 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe env->ReleaseStringUTFChars(serTailJstr, ptr); } - std::unique_ptr reader = createReader(orc::readFile(filePath), readerOptions); + std::vector tokens; + std::unique_ptr reader = createReader(orc::readFileRewrite(filePath, tokens), readerOptions); env->ReleaseStringUTFChars(path, pathPtr); orc::Reader 
*readerNew = reader.release(); return (jlong)(readerNew); JNI_FUNC_END(runtimeExceptionClass) } -bool stringToBool(string boolStr) +bool StringToBool(const std::string &boolStr) { - transform(boolStr.begin(), boolStr.end(), boolStr.begin(), ::tolower); - if (boolStr == "true") { - return true; - } else if (boolStr == "false") { - return false; + if (boost::iequals(boolStr, "true")) { + return true; + } else if (boost::iequals(boolStr, "false")) { + return false; } else { - throw std::runtime_error("Invalid input for stringToBool."); + throw std::runtime_error("Invalid input for stringToBool."); } } -int getLiteral(orc::Literal &lit, int leafType, string value) +int GetLiteral(orc::Literal &lit, int leafType, const std::string &value) { switch ((orc::PredicateDataType)leafType) { case orc::PredicateDataType::LONG: { @@ -173,7 +104,7 @@ int getLiteral(orc::Literal &lit, int leafType, string value) break; } case orc::PredicateDataType::BOOLEAN: { - lit = orc::Literal(static_cast(stringToBool(value))); + lit = orc::Literal(static_cast(StringToBool(value))); break; } default: { @@ -183,8 +114,8 @@ int getLiteral(orc::Literal &lit, int leafType, string value) return 0; } -int buildLeaves(PredicateOperatorType leafOp, vector &litList, Literal &lit, string leafNameString, PredicateDataType leafType, - SearchArgumentBuilder &builder) +int BuildLeaves(PredicateOperatorType leafOp, vector &litList, Literal &lit, const std::string &leafNameString, + PredicateDataType leafType, SearchArgumentBuilder &builder) { switch (leafOp) { case PredicateOperatorType::LESS_THAN: { @@ -234,7 +165,7 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo if (leafValue != nullptr) { std::string leafValueString(env->GetStringUTFChars(leafValue, nullptr)); if (leafValueString.size() != 0) { - getLiteral(lit, leafType, leafValueString); + GetLiteral(lit, leafType, leafValueString); } } std::vector litList; @@ -244,11 +175,11 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo for (int i = 0; i < childs; i++) { jstring child = (jstring)env->CallObjectMethod(litListValue, arrayListGet, i); std::string childString(env->GetStringUTFChars(child, nullptr)); - getLiteral(lit, leafType, childString); + GetLiteral(lit, leafType, childString); litList.push_back(lit); } } - buildLeaves((PredicateOperatorType)leafOp, litList, lit, leafNameString, (PredicateDataType)leafType, builder); + BuildLeaves((PredicateOperatorType)leafOp, litList, lit, leafNameString, (PredicateDataType)leafType, builder); return 1; } @@ -346,133 +277,225 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe JNI_FUNC_END(runtimeExceptionClass) } -template uint64_t copyFixwidth(orc::ColumnVectorBatch *field) +template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field) { - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - originalVector->SetValue(i, (T)(lvb->data.data()[i])); - } else { - originalVector->SetValueNull(i); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + // Check ColumnVectorBatch has null or not firstly + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if 
(notNulls[i]) { + originalVector->SetValue(i, (T)(values[i])); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + originalVector->SetValue(i, (T)(values[i])); + } + } + return (uint64_t)originalVector; +} + +template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field) +{ + using T = typename NativeType::type; + ORC_TYPE *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + // Check ColumnVectorBatch has null or not firstly + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (!notNulls[i]) { + originalVector->SetNull(i); + } + } + } + originalVector->SetValues(0, values, numElements); + return (uint64_t)originalVector; +} + +uint64_t CopyVarWidth(orc::ColumnVectorBatch *field) +{ + orc::StringVectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto lens = lvb->length.data(); + auto originalVector = new Vector>(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); + originalVector->SetValue(i, data); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); + originalVector->SetValue(i, data); } } return (uint64_t)originalVector; } +inline void FindLastNotEmpty(const char *chars, long &len) +{ + while (len > 0 && chars[len - 1] == ' ') { + len--; + } +} -uint64_t copyVarwidth(int maxLen, orc::ColumnVectorBatch *field, int vcType) +uint64_t CopyCharType(orc::ColumnVectorBatch *field) { - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); orc::StringVectorBatch *lvb = dynamic_cast(field); - uint64_t totalLen = - maxLen * (lvb->numElements) > lvb->getMemoryUsage() ? 
maxLen * (lvb->numElements) : lvb->getMemoryUsage(); - VarcharVector *originalVector = new VarcharVector(allocator, totalLen, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - string tmpStr(reinterpret_cast(lvb->data.data()[i]), lvb->length.data()[i]); - if (vcType == orc::TypeKind::CHAR && tmpStr.back() == ' ') { - tmpStr.erase(tmpStr.find_last_not_of(" ") + 1); + auto numElements = lvb->numElements; + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto lens = lvb->length.data(); + auto originalVector = new Vector>(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto chars = reinterpret_cast(values[i]); + auto len = lens[i]; + FindLastNotEmpty(chars, len); + auto data = std::string_view(chars, len); + originalVector->SetValue(i, data); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + auto chars = reinterpret_cast(values[i]); + auto len = lens[i]; + FindLastNotEmpty(chars, len); + auto data = std::string_view(chars, len); + originalVector->SetValue(i, data); + } + } + return (uint64_t)originalVector; +} + +inline void TransferDecimal128(int64_t &highbits, uint64_t &lowbits) +{ + if (highbits < 0) { // int128's 2s' complement code + lowbits = ~lowbits + 1; // 2s' complement code + highbits = ~highbits; //1s' complement code + if (lowbits == 0) { + highbits += 1; // carry a number as in adding + } + highbits ^= ((uint64_t)1 << 63); + } +} + +uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) +{ + orc::Decimal128VectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (notNulls[i]) { + auto highbits = values[i].getHighBits(); + auto lowbits = values[i].getLowBits(); + TransferDecimal128(highbits, lowbits); + Decimal128 d128(highbits, lowbits); + originalVector->SetValue(i, d128); + } else { + originalVector->SetNull(i); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + auto highbits = values[i].getHighBits(); + auto lowbits = values[i].getLowBits(); + TransferDecimal128(highbits, lowbits); + Decimal128 d128(highbits, lowbits); + originalVector->SetValue(i, d128); + } + } + return (uint64_t)originalVector; +} + +uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) +{ + orc::Decimal64VectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + auto originalVector = new Vector(numElements); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (!notNulls[i]) { + originalVector->SetNull(i); } - originalVector->SetValue(i, reinterpret_cast(tmpStr.data()), tmpStr.length()); - } else { - originalVector->SetValueNull(i); } } + originalVector->SetValues(0, values, numElements); return (uint64_t)originalVector; } -int copyToOmniVec(orc::TypeKind vcType, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, ...) 
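// A minimal standalone sketch (editor's illustration, not part of the patch) of the
// two's-complement adjustment performed by TransferDecimal128 above: a negative ORC Int128,
// given as high/low 64-bit halves, is turned into its magnitude with the sign re-encoded in
// bit 127 before the pair is handed to the omni Decimal128 constructor. The helper name
// ToSignMagnitude, the sample value, and the __int128 self-check are assumptions for
// illustration and require a GCC/Clang-style compiler with __int128 support.
#include <cassert>
#include <cstdint>

static void ToSignMagnitude(int64_t &high, uint64_t &low)
{
    if (high < 0) {                  // negative two's-complement value
        low = ~low + 1;              // negate the low 64 bits
        high = ~high;                // one's complement of the high 64 bits
        if (low == 0) {
            high += 1;               // propagate the carry out of the low half
        }
        high ^= (uint64_t)1 << 63;   // stamp the sign into bit 127
    }
}

int main()
{
    __int128 v = (__int128)(-1234567890123456789LL) * 1000;   // sample value wider than 64 bits
    int64_t high = (int64_t)(v >> 64);                         // arithmetic shift (GCC/Clang behavior)
    uint64_t low = (uint64_t)v;
    ToSignMagnitude(high, low);
    uint64_t magHigh = (uint64_t)high & ~((uint64_t)1 << 63);
    __int128 magnitude = ((__int128)magHigh << 64) | low;
    assert(magnitude == -v);                                   // magnitude of v recovered
    assert(((uint64_t)high >> 63) == 1);                       // sign kept in bit 127
    return 0;
}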
+int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field) { - switch (vcType) { - case orc::TypeKind::BOOLEAN: { + switch (type->getKind()) { + case orc::TypeKind::BOOLEAN: omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::SHORT: { + case orc::TypeKind::SHORT: omniTypeId = static_cast(OMNI_SHORT); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::DATE: { + case orc::TypeKind::DATE: omniTypeId = static_cast(OMNI_DATE32); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::INT: { + case orc::TypeKind::INT: omniTypeId = static_cast(OMNI_INT); - omniVecId = copyFixwidth(field); + omniVecId = CopyFixedWidth(field); break; - } - case orc::TypeKind::LONG: { + case orc::TypeKind::LONG: omniTypeId = static_cast(OMNI_LONG); - omniVecId = copyFixwidth(field); + omniVecId = CopyOptimizedForInt64(field); break; - } - case orc::TypeKind::DOUBLE: { + case orc::TypeKind::DOUBLE: omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = copyFixwidth(field); + omniVecId = CopyOptimizedForInt64(field); break; - } case orc::TypeKind::CHAR: + omniTypeId = static_cast(OMNI_VARCHAR); + omniVecId = CopyCharType(field); + break; case orc::TypeKind::STRING: - case orc::TypeKind::VARCHAR: { + case orc::TypeKind::VARCHAR: omniTypeId = static_cast(OMNI_VARCHAR); - va_list args; - va_start(args, field); - omniVecId = (uint64_t)copyVarwidth(va_arg(args, int), field, vcType); - va_end(args); + omniVecId = CopyVarWidth(field); break; - } - default: { - throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + vcType); - } - } - return 1; -} - -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field) -{ - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - if (precision > 18) { - omniTypeId = static_cast(OMNI_DECIMAL128); - orc::Decimal128VectorBatch *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = - new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - int64_t highbits = lvb->values.data()[i].getHighBits(); - uint64_t lowbits = lvb->values.data()[i].getLowBits(); - if (highbits < 0) { // int128's 2s' complement code - lowbits = ~lowbits + 1; // 2s' complement code - highbits = ~highbits; //1s' complement code - if (lowbits == 0) { - highbits += 1; // carry a number as in adding - } - highbits ^= ((uint64_t)1 << 63); - } - Decimal128 d128(highbits, lowbits); - originalVector->SetValue(i, d128); + case orc::TypeKind::DECIMAL: + if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { + omniTypeId = static_cast(OMNI_DECIMAL128); + omniVecId = CopyToOmniDecimal128Vec(field); } else { - originalVector->SetValueNull(i); - } - } - omniVecId = (uint64_t)originalVector; - } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - orc::Decimal64VectorBatch *lvb = dynamic_cast(field); - FixedWidthVector *originalVector = new FixedWidthVector(allocator, lvb->numElements); - for (uint i = 0; i < lvb->numElements; i++) { - if (lvb->notNull.data()[i]) { - originalVector->SetValue(i, (int64_t)(lvb->values.data()[i])); - } else { - originalVector->SetValueNull(i); + omniTypeId = static_cast(OMNI_DECIMAL64); + omniVecId = CopyToOmniDecimal64Vec(field); } + break; + default: { + throw std::runtime_error("Native 
ColumnarFileScan Not support For This Type: " + type->getKind()); } - omniVecId = (uint64_t)originalVector; } return 1; } @@ -491,16 +514,10 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; for (int id = 0; id < vecCnt; id++) { - orc::TypeKind vcType = baseTp.getSubtype(id)->getKind(); - int maxLen = baseTp.getSubtype(id)->getMaximumLength(); + auto type = baseTp.getSubtype(id); int omniTypeId = 0; uint64_t omniVecId = 0; - if (vcType != orc::TypeKind::DECIMAL) { - copyToOmniVec(vcType, omniTypeId, omniVecId, root->fields[id], maxLen); - } else { - copyToOmniDecimalVec(baseTp.getSubtype(id)->getPrecision(), omniTypeId, omniVecId, - root->fields[id]); - } + CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id]); env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); jlong omniVec = static_cast(omniVecId); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h index 975de176f9c99f5bb78001a3beb88db5d43d9298..878af0242301d5e4d0ac3375108a9b7e86d7ee58 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -22,28 +22,29 @@ #ifndef THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H #define THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H -#include "orc/ColumnPrinter.hh" -#include "orc/Exceptions.hh" -#include "orc/Type.hh" -#include "orc/Vector.hh" -#include "orc/Reader.hh" -#include "orc/OrcFile.hh" -#include "orc/MemoryPool.hh" -#include "orc/sargs/SearchArgument.hh" -#include "orc/sargs/Literal.hh" -#include -#include #include #include #include -#include -#include "jni.h" -#include "json/json.h" -#include "vector/vector_common.h" -#include "util/omni_exception.h" -#include +#include +#include #include -#include "../common/debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/debug.h" + +#include "io/orcfile/OrcFileRewrite.hh" #ifdef __cplusplus extern "C" { @@ -135,18 +136,14 @@ JNIEXPORT jobjectArray JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBat JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch); -int getLiteral(orc::Literal &lit, int leafType, std::string value); - -int buildLeaves(PredicateOperatorType leafOp, std::vector &litList, orc::Literal &lit, std::string leafNameString, orc::PredicateDataType leafType, - orc::SearchArgumentBuilder &builder); - -bool stringToBool(std::string boolStr); +int GetLiteral(orc::Literal &lit, int leafType, const std::string &value); -int copyToOmniVec(orc::TypeKind vcType, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, ...); +int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList, orc::Literal &lit, + const std::string &leafNameString, orc::PredicateDataType leafType, orc::SearchArgumentBuilder &builder); -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); +bool StringToBool(const std::string &boolStr); -int copyToOmniDecimalVec(int precision, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); +int CopyToOmniVec(const orc::Type *type, int 
&omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); #ifdef __cplusplus } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fda647658a2477e3c7ac213fa9223a64cab09f39 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ParquetColumnarBatchJniReader.h" +#include "jni_common.h" +#include "tablescan/ParquetReader.h" + +using namespace omniruntime::vec; +using namespace omniruntime::type; +using namespace std; +using namespace arrow; +using namespace parquet::arrow; +using namespace spark::reader; + +std::vector GetIndices(JNIEnv *env, jobject jsonObj, const char* name) +{ + jintArray indicesArray = (jintArray)env->CallObjectMethod(jsonObj, jsonMethodObj, env->NewStringUTF(name)); + auto length = static_cast(env->GetArrayLength(indicesArray)); + auto ptr = env->GetIntArrayElements(indicesArray, JNI_FALSE); + std::vector indices; + for (int32_t i = 0; i < length; i++) { + indices.push_back(ptr[i]); + } + env->ReleaseIntArrayElements(indicesArray, ptr, 0); + return indices; +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_initializeReader(JNIEnv *env, + jobject jObj, jobject jsonObj) +{ + JNI_FUNC_START + // Get filePath + jstring path = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("filePath")); + const char *filePath = env->GetStringUTFChars(path, JNI_FALSE); + std::string file(filePath); + env->ReleaseStringUTFChars(path, filePath); + + jstring ugiTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("ugi")); + const char *ugi = env->GetStringUTFChars(ugiTemp, JNI_FALSE); + std::string ugiString(ugi); + env->ReleaseStringUTFChars(ugiTemp, ugi); + + // Get capacity for each record batch + int64_t capacity = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("capacity")); + + // Get RowGroups and Columns indices + auto row_group_indices = GetIndices(env, jsonObj, "rowGroupIndices"); + auto column_indices = GetIndices(env, jsonObj, "columnIndices"); + + ParquetReader *pReader = new ParquetReader(); + auto state = pReader->InitRecordReader(file, capacity, row_group_indices, column_indices, ugiString); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + return (jlong)(pReader); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL 
Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderNext(JNIEnv *env, + jobject jObj, jlong reader, jintArray typeId, jlongArray vecNativeId) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; + std::shared_ptr recordBatchPtr; + auto state = pReader->ReadNextBatch(&recordBatchPtr); + if (state != Status::OK()) { + env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); + return 0; + } + int vecCnt = 0; + long batchRowSize = 0; + if (recordBatchPtr != NULL) { + batchRowSize = recordBatchPtr->num_rows(); + vecCnt = recordBatchPtr->num_columns(); + std::vector> fields = recordBatchPtr->schema()->fields(); + + for (int colIdx = 0; colIdx < vecCnt; colIdx++) { + std::shared_ptr array = recordBatchPtr->column(colIdx); + // One array in current batch + std::shared_ptr data = array->data(); + int omniTypeId = 0; + uint64_t omniVecId = 0; + spark::reader::CopyToOmniVec(data->type, omniTypeId, omniVecId, array); + + env->SetIntArrayRegion(typeId, colIdx, 1, &omniTypeId); + jlong omniVec = static_cast(omniVecId); + env->SetLongArrayRegion(vecNativeId, colIdx, 1, &omniVec); + } + } + return (jlong)batchRowSize; + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderClose(JNIEnv *env, + jobject jObj, jlong reader) +{ + JNI_FUNC_START + ParquetReader *pReader = (ParquetReader *)reader; + if (nullptr == pReader) { + env->ThrowNew(runtimeExceptionClass, "delete nullptr error for parquet reader"); + return; + } + delete pReader; + JNI_FUNC_END_VOID(runtimeExceptionClass) +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h b/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h new file mode 100644 index 0000000000000000000000000000000000000000..9f47c6fb7a4731a53191e0301ed64dc7da1a282b --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h @@ -0,0 +1,70 @@ +/** + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SPARK_THESTRAL_PLUGIN_PARQUETCOLUMNARBATCHJNIREADER_H +#define SPARK_THESTRAL_PLUGIN_PARQUETCOLUMNARBATCHJNIREADER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/debug.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Method: initializeReader + * Signature: (Ljava/lang/String;Lorg/json/simple/JSONObject;)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_initializeReader + (JNIEnv* env, jobject jObj, jobject job); + +/* + * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Method: recordReaderNext + * Signature: (J[I[J)J + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderNext + (JNIEnv *, jobject, jlong, jintArray, jlongArray); + +/* + * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Method: recordReaderClose + * Signature: (J)F + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderClose + (JNIEnv *, jobject, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index 2f75c23a770b8d40d61ec575b035f899ed22decb..ca982c0a4ca56100cb6c11599d6d0c334009da92 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -20,68 +20,31 @@ #include #include -#include "../io/SparkFile.hh" -#include "../io/ColumnWriter.hh" -#include "../shuffle/splitter.h" +#include "io/SparkFile.hh" +#include "io/ColumnWriter.hh" #include "jni_common.h" #include "SparkJniWrapper.hh" -#include "concurrent_map.h" - -static jint JNI_VERSION = JNI_VERSION_1_8; - -static jclass split_result_class; -static jclass runtime_exception_class; - -static jmethodID split_result_constructor; using namespace spark; using namespace google::protobuf::io; using namespace omniruntime::vec; -static ConcurrentMap> shuffle_splitter_holder_; - -jint JNI_OnLoad(JavaVM* vm, void* reserved) { - JNIEnv* env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { - return JNI_ERR; - } - - split_result_class = - CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); - split_result_constructor = GetMethodID(env, split_result_class, "", "(JJJJJ[J)V"); - - runtime_exception_class = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); - - return JNI_VERSION; -} - -void JNI_OnUnload(JavaVM* vm, void* reserved) { - JNIEnv* env; - vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); - - env->DeleteGlobalRef(split_result_class); - - env->DeleteGlobalRef(runtime_exception_class); - - shuffle_splitter_holder_.Clear(); -} - -JNIEXPORT jlong JNICALL -Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( JNIEnv* env, jobject, jstring partitioning_name_jstr, jint num_partitions, jstring jInputType, jint jNumCols, jint buffer_size, jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs, jstring local_dirs_jstr, jlong compress_block_size, - jint spill_batch_row, jlong spill_memory_threshold) { + jint spill_batch_row, jlong spill_memory_threshold) +{ JNI_FUNC_START 
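    // Validate the JNI arguments before SplitOptions is built: a null partitioning name or a
    // null serialized input-type string cannot be deserialized, so a Java RuntimeException is
    // raised through the shared runtimeExceptionClass and the call returns 0 early.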
if (partitioning_name_jstr == nullptr) { - env->ThrowNew(runtime_exception_class, - std::string("Short partitioning name can't be null").c_str()); + env->ThrowNew(runtimeExceptionClass, + std::string("Short partitioning name can't be null").c_str()); return 0; } if (jInputType == nullptr) { - env->ThrowNew(runtime_exception_class, - std::string("input types can't be null").c_str()); + env->ThrowNew(runtimeExceptionClass, + std::string("input types can't be null").c_str()); return 0; } @@ -89,17 +52,17 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( DataTypes inputVecTypes = Deserialize(inputTypeCharPtr); const int32_t *inputVecTypeIds = inputVecTypes.GetIds(); // - std::vector inputDataTpyes = inputVecTypes.Get(); - int32_t size = inputDataTpyes.size(); + std::vector inputDataTypes = inputVecTypes.Get(); + int32_t size = inputDataTypes.size(); uint32_t *inputDataPrecisions = new uint32_t[size]; uint32_t *inputDataScales = new uint32_t[size]; for (int i = 0; i < size; ++i) { - if(inputDataTpyes[i]->GetId() == OMNI_DECIMAL64 || inputDataTpyes[i]->GetId() == OMNI_DECIMAL128) { - inputDataScales[i] = std::dynamic_pointer_cast(inputDataTpyes[i])->GetScale(); - inputDataPrecisions[i] = std::dynamic_pointer_cast(inputDataTpyes[i])->GetPrecision(); + if (inputDataTypes[i]->GetId() == OMNI_DECIMAL64 || inputDataTypes[i]->GetId() == OMNI_DECIMAL128) { + inputDataScales[i] = std::dynamic_pointer_cast(inputDataTypes[i])->GetScale(); + inputDataPrecisions[i] = std::dynamic_pointer_cast(inputDataTypes[i])->GetPrecision(); } } - inputDataTpyes.clear(); + inputDataTypes.clear(); InputDataTypes inputDataTypesTmp; inputDataTypesTmp.inputVecTypeIds = (int32_t *)inputVecTypeIds; @@ -107,13 +70,13 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( inputDataTypesTmp.inputDataScales = inputDataScales; if (data_file_jstr == nullptr) { - env->ThrowNew(runtime_exception_class, - std::string("Shuffle DataFile can't be null").c_str()); + env->ThrowNew(runtimeExceptionClass, + std::string("Shuffle DataFile can't be null").c_str()); return 0; } if (local_dirs_jstr == nullptr) { - env->ThrowNew(runtime_exception_class, - std::string("Shuffle DataFile can't be null").c_str()); + env->ThrowNew(runtimeExceptionClass, + std::string("Shuffle DataFile can't be null").c_str()); return 0; } @@ -141,40 +104,38 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( setenv("NATIVESQL_SPARK_LOCAL_DIRS", local_dirs, 1); env->ReleaseStringUTFChars(local_dirs_jstr, local_dirs); - if (spill_batch_row > 0){ + if (spill_batch_row > 0) { splitOptions.spill_batch_row_num = spill_batch_row; } - if (spill_memory_threshold > 0){ + if (spill_memory_threshold > 0) { splitOptions.spill_mem_threshold = spill_memory_threshold; } - if (compress_block_size > 0){ + if (compress_block_size > 0) { splitOptions.compress_block_size = compress_block_size; } - jclass cls = env->FindClass("java/lang/Thread"); - jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;"); - jobject thread = env->CallStaticObjectMethod(cls, mid); + jobject thread = env->CallStaticObjectMethod(threadClass, currentThread); if (thread == NULL) { std::cout << "Thread.currentThread() return NULL" <GetMethodID(cls, "getId", "()J"); - jlong sid = env->CallLongMethod(thread, mid_getid); + jlong sid = env->CallLongMethod(thread, threadGetId); splitOptions.thread_id = (int64_t)sid; } - auto splitter = Splitter::Make(partitioning_name, inputDataTypesTmp, jNumCols, num_partitions, std::move(splitOptions)); - return 
shuffle_splitter_holder_.Insert(std::shared_ptr(splitter)); - JNI_FUNC_END(runtime_exception_class) + auto splitter = Splitter::Make(partitioning_name, inputDataTypesTmp, jNumCols, num_partitions, + std::move(splitOptions)); + return g_shuffleSplitterHolder.Insert(std::shared_ptr(splitter)); + JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT jlong JNICALL -Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( - JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) { +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( + JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) +{ JNI_FUNC_START - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); - env->ThrowNew(runtime_exception_class, error_message.c_str()); + env->ThrowNew(runtimeExceptionClass, error_message.c_str()); return -1; } @@ -182,17 +143,17 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( splitter->Split(*vecBatch); return 0L; - JNI_FUNC_END(runtime_exception_class) + JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT jobject JNICALL -Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( - JNIEnv* env, jobject, jlong splitter_id) { +JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( + JNIEnv* env, jobject, jlong splitter_id) +{ JNI_FUNC_START - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); - env->ThrowNew(runtime_exception_class, error_message.c_str()); + env->ThrowNew(runtimeExceptionClass, error_message.c_str()); } splitter->Stop(); @@ -201,23 +162,23 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( auto src = reinterpret_cast(partition_length.data()); env->SetLongArrayRegion(partition_length_arr, 0, partition_length.size(), src); jobject split_result = env->NewObject( - split_result_class, split_result_constructor, splitter->TotalComputePidTime(), + splitResultClass, splitResultConstructor, splitter->TotalComputePidTime(), splitter->TotalWriteTime(), splitter->TotalSpillTime(), splitter->TotalBytesWritten(), splitter->TotalBytesSpilled(), partition_length_arr); return split_result; - JNI_FUNC_END(runtime_exception_class) + JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT void JNICALL -Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( - JNIEnv* env, jobject, jlong splitter_id) { +JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( + JNIEnv* env, jobject, jlong splitter_id) +{ JNI_FUNC_START - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); - env->ThrowNew(runtime_exception_class, error_message.c_str()); + env->ThrowNew(runtimeExceptionClass, error_message.c_str()); } - shuffle_splitter_holder_.Erase(splitter_id); - JNI_FUNC_END_VOID(runtime_exception_class) + g_shuffleSplitterHolder.Erase(splitter_id); + JNI_FUNC_END_VOID(runtimeExceptionClass) } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh index 
91ff665e4ea2448295722b9260615207074d801d..c98c10383c4cab04c4770adb8ebdab0ebdb4424b 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh @@ -20,6 +20,8 @@ #include #include #include +#include "concurrent_map.h" +#include "shuffle/splitter.h" #ifndef SPARK_JNI_WRAPPER #define SPARK_JNI_WRAPPER @@ -51,6 +53,8 @@ JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( JNIEnv* env, jobject, jlong splitter_id); +static ConcurrentMap> g_shuffleSplitterHolder; + #ifdef __cplusplus } #endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp index 4beb855ca4c3dd271648bca85e1a1f015d4a0f84..f0e3a225363913268b885cbcde2903d00eea7476 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp @@ -21,8 +21,31 @@ #define THESTRAL_PLUGIN_MASTER_JNI_COMMON_CPP #include "jni_common.h" +#include "io/SparkFile.hh" +#include "SparkJniWrapper.hh" -spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr) { +jclass runtimeExceptionClass; +jclass splitResultClass; +jclass jsonClass; +jclass arrayListClass; +jclass threadClass; + +jmethodID jsonMethodInt; +jmethodID jsonMethodLong; +jmethodID jsonMethodHas; +jmethodID jsonMethodString; +jmethodID jsonMethodJsonObj; +jmethodID arrayListGet; +jmethodID arrayListSize; +jmethodID jsonMethodObj; +jmethodID splitResultConstructor; +jmethodID currentThread; +jmethodID threadGetId; + +static jint JNI_VERSION = JNI_VERSION_1_8; + +spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr) +{ auto codec_c = env->GetStringUTFChars(codec_jstr, JNI_FALSE); auto codec = std::string(codec_c); auto compression_type = GetCompressionType(codec); @@ -30,16 +53,64 @@ spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr) { return compression_type; } -jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) +{ jclass local_class = env->FindClass(class_name); jclass global_class = (jclass)env->NewGlobalRef(local_class); env->DeleteLocalRef(local_class); return global_class; } -jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) { +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) +{ jmethodID ret = env->GetMethodID(this_class, name, sig); return ret; } +jint JNI_OnLoad(JavaVM* vm, void* reserved) +{ + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + runtimeExceptionClass = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + + splitResultClass = + CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); + splitResultConstructor = GetMethodID(env, splitResultClass, "", "(JJJJJ[J)V"); + + jsonClass = CreateGlobalClassReference(env, "org/json/JSONObject"); + jsonMethodInt = env->GetMethodID(jsonClass, "getInt", "(Ljava/lang/String;)I"); + jsonMethodLong = env->GetMethodID(jsonClass, "getLong", "(Ljava/lang/String;)J"); + jsonMethodHas = env->GetMethodID(jsonClass, "has", "(Ljava/lang/String;)Z"); + jsonMethodString = env->GetMethodID(jsonClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); + jsonMethodJsonObj = env->GetMethodID(jsonClass, "getJSONObject", 
"(Ljava/lang/String;)Lorg/json/JSONObject;"); + jsonMethodObj = env->GetMethodID(jsonClass, "get", "(Ljava/lang/String;)Ljava/lang/Object;"); + + arrayListClass = CreateGlobalClassReference(env, "java/util/ArrayList"); + arrayListGet = env->GetMethodID(arrayListClass, "get", "(I)Ljava/lang/Object;"); + arrayListSize = env->GetMethodID(arrayListClass, "size", "()I"); + + threadClass = CreateGlobalClassReference(env, "java/lang/Thread"); + currentThread = env->GetStaticMethodID(threadClass, "currentThread", "()Ljava/lang/Thread;"); + threadGetId = env->GetMethodID(threadClass, "getId", "()J"); + + return JNI_VERSION; +} + +void JNI_OnUnload(JavaVM* vm, void* reserved) +{ + JNIEnv* env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + + env->DeleteGlobalRef(runtimeExceptionClass); + env->DeleteGlobalRef(splitResultClass); + env->DeleteGlobalRef(jsonClass); + env->DeleteGlobalRef(arrayListClass); + env->DeleteGlobalRef(threadClass); + + g_shuffleSplitterHolder.Clear(); +} + #endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_CPP diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h index e21fd444d383263f47ba4773ca70173f736c0093..4b59296e152876062a06db3d69c81a7ed22b670b 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h @@ -21,8 +21,7 @@ #define THESTRAL_PLUGIN_MASTER_JNI_COMMON_H #include - -#include "../common/common.h" +#include "common/common.h" spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr); @@ -49,4 +48,22 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const ch return; \ } \ +extern jclass runtimeExceptionClass; +extern jclass splitResultClass; +extern jclass jsonClass; +extern jclass arrayListClass; +extern jclass threadClass; + +extern jmethodID jsonMethodInt; +extern jmethodID jsonMethodLong; +extern jmethodID jsonMethodHas; +extern jmethodID jsonMethodString; +extern jmethodID jsonMethodJsonObj; +extern jmethodID arrayListGet; +extern jmethodID arrayListSize; +extern jmethodID jsonMethodObj; +extern jmethodID splitResultConstructor; +extern jmethodID currentThread; +extern jmethodID threadGetId; + #endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_H diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto index c40472020171692ea7b0acde2dd873efeda691f4..725f9fa070aa1f8d188d85118df9765a63d299f3 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto +++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto @@ -57,4 +57,4 @@ message VecType { NANOSEC = 3; } TimeUnit timeUnit = 6; -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index 8e66120276c9aee186c9effd13157d2992190510..addc16c71b2faa4fc5ba7f9f2169b5adc8b6390c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -37,10 +37,10 @@ int Splitter::ComputeAndCountPartitionId(VectorBatch& vb) { partition_id_[i] = 0; } } else { - IntVector* hashVct = static_cast(vb.GetVector(0)); + auto hash_vct = reinterpret_cast *>(vb.Get(0)); for (auto i = 0; i < num_rows; ++i) { // positive mod - int32_t pid = hashVct->GetValue(i); + int32_t pid = hash_vct->GetValue(i); if (pid >= 
num_partitions_) { LogsError(" Illegal pid Value: %d >= partition number %d .", pid, num_partitions_); throw std::runtime_error("Shuffle pidVec Illegal pid Value!"); @@ -76,7 +76,7 @@ int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) { case SHUFFLE_8BYTE: case SHUFFLE_DECIMAL128: default: { - void *ptr_tmp = static_cast(options_.allocator->alloc(new_size * (1 << column_type_id_[i]))); + void *ptr_tmp = static_cast(options_.allocator->Alloc(new_size * (1 << column_type_id_[i]))); fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! "); @@ -128,15 +128,11 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto col_idx_vb = fixed_width_array_idx_[col]; auto col_idx_schema = singlePartitionFlag ? col_idx_vb : (col_idx_vb - 1); const auto& dst_addrs = partition_fixed_width_value_addrs_[col]; - if (vb.GetVector(col_idx_vb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { + if (vb.Get(col_idx_vb)->GetEncoding() == OMNI_DICTIONARY) { LogsDebug("Dictionary Columnar process!"); - auto ids_tmp = static_cast(options_.allocator->alloc(num_rows * sizeof(int32_t))); - Buffer *ids (new Buffer((uint8_t*)ids_tmp, 0, num_rows * sizeof(int32_t))); - if (ids->data_ == nullptr) { - throw std::runtime_error("Allocator for SplitFixedWidthValueBuffer ids Failed! "); - } - auto dictionaryTmp = ((DictionaryVector *)(vb.GetVector(col_idx_vb)))->ExtractDictionaryAndIds(0, num_rows, (int32_t *)(ids->data_)); - auto src_addr = VectorHelper::GetValuesAddr(dictionaryTmp); + + auto ids_addr = VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb)); + auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vb.Get(col_idx_vb))); switch (column_type_id_[col_idx_schema]) { #define PROCESS(SHUFFLE_TYPE, CTYPE) \ case SHUFFLE_TYPE: \ @@ -145,8 +141,8 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto dst_offset = \ partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ reinterpret_cast(dst_addrs[pid])[dst_offset] = \ - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row]]; \ - partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ + reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row]]; \ + partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_TYPE); \ partition_buffer_idx_offset_[pid]++; \ } \ break; @@ -160,10 +156,12 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { auto pid = partition_id_[row]; auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + // 前64位取值、赋值 reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1]; // 前64位取值、赋值 - reinterpret_cast(dst_addrs[pid])[dst_offset << 1 | 1] = - reinterpret_cast(src_addr)[reinterpret_cast(ids->data_)[row] << 1 | 1]; // 后64位取值、赋值 + reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row] << 1]; + // 后64位取值、赋值 + reinterpret_cast(dst_addrs[pid])[(dst_offset << 1) | 1] = + reinterpret_cast(src_addr)[(reinterpret_cast(ids_addr)[row] << 1) | 1]; partition_fixed_width_buffers_[col][pid][1]->size_ += (1 << SHUFFLE_DECIMAL128); //decimal128 16Bytes partition_buffer_idx_offset_[pid]++; @@ -174,13 +172,8 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { throw std::runtime_error("SplitFixedWidthValueBuffer not match this type: " + column_type_id_[col_idx_schema]); } } - options_.allocator->free(ids->data_, 
ids->capacity_); - if (nullptr == ids) { - throw std::runtime_error("delete nullptr error for ids"); - } - delete ids; } else { - auto src_addr = VectorHelper::GetValuesAddr(vb.GetVector(col_idx_vb)); + auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb))); switch (column_type_id_[col_idx_schema]) { #define PROCESS(SHUFFLE_TYPE, CTYPE) \ case SHUFFLE_TYPE: \ @@ -225,53 +218,65 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { int Splitter::SplitBinaryArray(VectorBatch& vb) { - const auto numRows = vb.GetRowCount(); - auto vecCntVb = vb.GetVectorCount(); - auto vecCntSchema = singlePartitionFlag ? vecCntVb : vecCntVb - 1; - for (auto colSchema = 0; colSchema < vecCntSchema; ++colSchema) { - switch (column_type_id_[colSchema]) { + const auto num_rows = vb.GetRowCount(); + auto vec_cnt_vb = vb.GetVectorCount(); + auto vec_cnt_schema = singlePartitionFlag ? vec_cnt_vb : vec_cnt_vb - 1; + for (auto col_schema = 0; col_schema < vec_cnt_schema; ++col_schema) { + switch (column_type_id_[col_schema]) { case SHUFFLE_BINARY: { - auto colVb = singlePartitionFlag ? colSchema : colSchema + 1; - if (vb.GetVector(colVb)->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { - for (auto row = 0; row < numRows; ++row) { + auto col_vb = singlePartitionFlag ? col_schema : col_schema + 1; + varcharVectorCache.insert(vb.Get(col_vb)); + if (vb.Get(col_vb)->GetEncoding() == OMNI_DICTIONARY) { + auto vc = reinterpret_cast> *>( + vb.Get(col_vb)); + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; - auto str_len = ((DictionaryVector *)(vb.GetVector(colVb)))->GetVarchar(row, &dst); - bool isnull = ((DictionaryVector *)(vb.GetVector(colVb)))->IsValueNull(row); + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + bool is_null = vc->IsNull(row); cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, isnull); - if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && - (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < options_.spill_batch_row_num)) { - vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; } else { VCBatchInfo svc(options_.spill_batch_row_num); svc.getVcList().push_back(cl); svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][colSchema].push_back(svc); + vc_partition_array_buffers_[pid][col_schema].push_back(svc); } } } else { - VarcharVector *vc = nullptr; - vc = static_cast(vb.GetVector(colVb)); - for (auto row = 0; row < numRows; ++row) { + auto vc = reinterpret_cast> *>(vb.Get(col_vb)); + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; - int str_len = vc->GetValue(row, &dst); - bool isnull = vc->IsValueNull(row); + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = 
static_cast(value.length()); + } + bool is_null = vc->IsNull(row); cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, isnull); - if ((vc_partition_array_buffers_[pid][colSchema].size() != 0) && - (vc_partition_array_buffers_[pid][colSchema].back().getVcList().size() < + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < options_.spill_batch_row_num)) { - vc_partition_array_buffers_[pid][colSchema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][colSchema].back().vcb_total_len += str_len; + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; } else { VCBatchInfo svc(options_.spill_batch_row_num); svc.getVcList().push_back(cl); svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][colSchema].push_back(svc); + vc_partition_array_buffers_[pid][col_schema].push_back(svc); } } } @@ -296,7 +301,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ if (partition_id_cnt_cur_[pid] > 0 && dst_addrs[pid] == nullptr) { // init bitmap if it's null auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? partition_id_cnt_cur_[pid] : options_.buffer_size; - auto ptr_tmp = static_cast(options_.allocator->alloc(new_size)); + auto ptr_tmp = static_cast(options_.allocator->Alloc(new_size)); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for ValidityBuffer Failed! "); } @@ -309,7 +314,8 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ } // 计算并填充数据 - auto src_addr = const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vb.GetVector(col_idx)))); + auto src_addr = const_cast((uint8_t *)( + reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vb.Get(col_idx))))); std::fill(std::begin(partition_buffer_idx_offset_), std::end(partition_buffer_idx_offset_), 0); const auto num_rows = vb.GetRowCount(); @@ -401,17 +407,20 @@ int Splitter::DoSplit(VectorBatch& vb) { // Binary split last vector batch... SplitBinaryArray(vb); - vectorBatch_cache_.push_back(&vb); // record for release vector + num_row_splited_ += vb.GetRowCount(); + // release the fixed width vector and release vectorBatch at the same time + ReleaseVectorBatch(&vb); // 阈值检查,是否溢写 - num_row_splited_ += vb.GetRowCount(); - if (num_row_splited_ + vb.GetRowCount() >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { + if (num_row_splited_ >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { LogsDebug(" Spill For Row Num Threshold."); TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); + isSpill = true; } if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.spill_mem_threshold) { LogsDebug(" Spill For Memory Size Threshold."); TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); + isSpill = true; } return 0; } @@ -548,21 +557,21 @@ int Splitter::Split(VectorBatch& vb ) } std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() { - void *ptr_tmp = static_cast(options_.allocator->alloc((num_partitions_ + 1) * sizeof(uint32_t))); + void *ptr_tmp = static_cast(options_.allocator->Alloc((num_partitions_ + 1) * sizeof(uint64_t))); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for partitionOffsets Failed! 
"); } - std::shared_ptr ptrPartitionOffsets (new Buffer((uint8_t*)ptr_tmp, 0, (num_partitions_ + 1) * sizeof(uint32_t))); - uint32_t pidOffset = 0; + std::shared_ptr ptrPartitionOffsets (new Buffer((uint8_t*)ptr_tmp, 0, (num_partitions_ + 1) * sizeof(uint64_t))); + uint64_t pidOffset = 0; // 顺序记录每个partition的offset auto pid = 0; for (pid = 0; pid < num_partitions_; ++pid) { - reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; + reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; pidOffset += partition_serialization_size_[pid]; // reset partition_cached_vectorbatch_size_ to 0 partition_serialization_size_[pid] = 0; } - reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; + reinterpret_cast(ptrPartitionOffsets->data_)[pid] = pidOffset; return ptrPartitionOffsets; } @@ -604,7 +613,7 @@ spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) { return spark::VecType::VEC_TYPE_CHAR; case OMNI_CONTAINER: return spark::VecType::VEC_TYPE_CONTAINER; - case OMNI_INVALID: + case DataTypeId::OMNI_INVALID: return spark::VecType::VEC_TYPE_INVALID; default: { throw std::runtime_error("castShuffleTypeIdToVecType() unexpected ShuffleTypeId"); @@ -623,9 +632,9 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, colIndexTmpSchema = singlePartitionFlag ? fixed_width_array_idx_[fixColIndexTmp] : fixed_width_array_idx_[fixColIndexTmp] - 1; auto onceCopyLen = splitRowInfoTmp->onceCopyRow * (1 << column_type_id_[colIndexTmpSchema]); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - void *ptr_value_tmp = static_cast(options_.allocator->alloc(onceCopyLen)); + void *ptr_value_tmp = static_cast(options_.allocator->Alloc(onceCopyLen)); std::shared_ptr ptr_value (new Buffer((uint8_t*)ptr_value_tmp, 0, onceCopyLen)); - void *ptr_validity_tmp = static_cast(options_.allocator->alloc(splitRowInfoTmp->onceCopyRow)); + void *ptr_validity_tmp = static_cast(options_.allocator->Alloc(splitRowInfoTmp->onceCopyRow)); std::shared_ptr ptr_validity (new Buffer((uint8_t*)ptr_validity_tmp, 0, splitRowInfoTmp->onceCopyRow)); if (nullptr == ptr_value->data_ || nullptr == ptr_validity->data_) { throw std::runtime_error("Allocator for tmp buffer Failed! 
"); @@ -657,9 +666,9 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); // 释放内存 - options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, + options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); - options_.allocator->free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, + options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); destCopyedLength += memCopyLen; splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 @@ -686,8 +695,8 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, vec.set_values(ptr_value->data_, onceCopyLen); vec.set_nulls(ptr_validity->data_, splitRowInfoTmp->onceCopyRow); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - options_.allocator->free(ptr_value->data_, ptr_value->capacity_); - options_.allocator->free(ptr_validity->data_, ptr_validity->capacity_); + options_.allocator->Free(ptr_value->data_, ptr_value->capacity_); + options_.allocator->Free(ptr_validity->data_, ptr_validity->capacity_); } // partition_cached_vectorbatch_[partition_id][cache_index][col][0]代表ByteMap, // partition_cached_vectorbatch_[partition_id][cache_index][col][1]代表value @@ -712,6 +721,88 @@ void Splitter::SerializingBinaryColumns(int32_t partitionId, spark::Vec& vec, in vec.set_offset(OffsetsByte.get(), (itemsTotalLen + 1) * sizeof(int32_t)); } +int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut) { + SplitRowInfo splitRowInfoTmp; + splitRowInfoTmp.copyedRow = 0; + splitRowInfoTmp.remainCopyRow = partition_id_cnt_cache_[partition_id]; + splitRowInfoTmp.cacheBatchIndex.resize(fixed_width_array_idx_.size()); + splitRowInfoTmp.cacheBatchCopyedLen.resize(fixed_width_array_idx_.size()); + + int curBatch = 0; + while (0 < splitRowInfoTmp.remainCopyRow) { + if (options_.spill_batch_row_num < splitRowInfoTmp.remainCopyRow) { + splitRowInfoTmp.onceCopyRow = options_.spill_batch_row_num; + } else { + splitRowInfoTmp.onceCopyRow = splitRowInfoTmp.remainCopyRow; + } + + vecBatchProto->set_rowcnt(splitRowInfoTmp.onceCopyRow); + vecBatchProto->set_veccnt(column_type_id_.size()); + int fixColIndexTmp = 0; + for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) { + spark::Vec * vec = vecBatchProto->add_vecs(); + switch (column_type_id_[indexSchema]) { + case ShuffleTypeId::SHUFFLE_1BYTE: + case ShuffleTypeId::SHUFFLE_2BYTE: + case ShuffleTypeId::SHUFFLE_4BYTE: + case ShuffleTypeId::SHUFFLE_8BYTE: + case ShuffleTypeId::SHUFFLE_DECIMAL128:{ + SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp); + fixColIndexTmp++; // 定长序列化数量++ + break; + } + case ShuffleTypeId::SHUFFLE_BINARY: { + 
SerializingBinaryColumns(partition_id, *vec, indexSchema, curBatch); + break; + } + default: { + throw std::runtime_error("Unsupported ShuffleType."); + } + } + spark::VecType *vt = vec->mutable_vectype(); + vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema])); + LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", + indexSchema, input_col_types.inputDataPrecisions[indexSchema], + indexSchema, input_col_types.inputDataScales[indexSchema]); + if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){ + vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]); + vt->set_scale(input_col_types.inputDataScales[indexSchema]); + } + } + curBatch++; + + if (vecBatchProto->ByteSizeLong() > UINT32_MAX) { + throw std::runtime_error("Unsafe static_cast long to uint_32t."); + } + uint32_t vecBatchProtoSize = reversebytes_uint32t(static_cast(vecBatchProto->ByteSizeLong())); + if (bufferStream->Next(&bufferOut, &sizeOut)) { + std::memcpy(bufferOut, &vecBatchProtoSize, sizeof(vecBatchProtoSize)); + if (sizeof(vecBatchProtoSize) < sizeOut) { + bufferStream->BackUp(sizeOut - sizeof(vecBatchProtoSize)); + } + } + + vecBatchProto->SerializeToZeroCopyStream(bufferStream.get()); + splitRowInfoTmp.remainCopyRow -= splitRowInfoTmp.onceCopyRow; + splitRowInfoTmp.copyedRow += splitRowInfoTmp.onceCopyRow; + vecBatchProto->Clear(); + } + + uint64_t partitionBatchSize = bufferStream->flush(); + total_bytes_written_ += partitionBatchSize; + partition_lengths_[partition_id] += partitionBatchSize; + LogsDebug(" partitionBatch write length: %lu", partitionBatchSize); + + // 及时清理分区数据 + partition_cached_vectorbatch_[partition_id].clear(); // 定长数据内存释放 + for (size_t col = 0; col < column_type_id_.size(); col++) { + vc_partition_array_buffers_[partition_id][col].clear(); // binary 释放内存 + } + + return 0; + +} + int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) { SplitRowInfo splitRowInfoTmp; splitRowInfoTmp.copyedRow = 0; @@ -819,6 +910,11 @@ int Splitter::WriteDataFileProto() { } void Splitter::MergeSpilled() { + for (auto pid = 0; pid < num_partitions_; ++pid) { + CacheVectorBatch(pid, true); + partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + } + std::unique_ptr outStream = writeLocalFile(options_.data_file); LogsDebug(" Merge Spilled Tmp File: %s ", options_.data_file.c_str()); WriterOptions options; @@ -831,17 +927,18 @@ void Splitter::MergeSpilled() { void* bufferOut = nullptr; int sizeOut = 0; for (int pid = 0; pid < num_partitions_; pid++) { + ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut); LogsDebug(" MergeSplled traversal partition( %d ) ",pid); for (auto &pair : spilled_tmp_files_info_) { auto tmpDataFilePath = pair.first + ".data"; - auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid]; - auto tmpPartitionSize = reinterpret_cast(pair.second->data_)[pid + 1] - reinterpret_cast(pair.second->data_)[pid]; + auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid]; + auto tmpPartitionSize = reinterpret_cast(pair.second->data_)[pid + 1] - reinterpret_cast(pair.second->data_)[pid]; LogsDebug(" get Partition Stream...tmpPartitionOffset %d tmpPartitionSize %d path %s", tmpPartitionOffset, tmpPartitionSize, tmpDataFilePath.c_str()); std::unique_ptr inputStream = readLocalFile(tmpDataFilePath); - int64_t targetLen = tmpPartitionSize; - int64_t seekPosit = tmpPartitionOffset; - int64_t onceReadLen = 0; + uint64_t targetLen = 
tmpPartitionSize; + uint64_t seekPosit = tmpPartitionOffset; + uint64_t onceReadLen = 0; while ((targetLen > 0) && bufferOutPutStream->Next(&bufferOut, &sizeOut)) { onceReadLen = targetLen > sizeOut ? sizeOut : targetLen; inputStream->read(bufferOut, onceReadLen, seekPosit); @@ -860,6 +957,38 @@ void Splitter::MergeSpilled() { partition_lengths_[pid] += flushSize; } } + + std::fill(std::begin(partition_id_cnt_cache_), std::end(partition_id_cnt_cache_), 0); + ReleaseVarcharVector(); + num_row_splited_ = 0; + cached_vectorbatch_size_ = 0; + outStream->close(); +} + +void Splitter::WriteSplit() { + for (auto pid = 0; pid < num_partitions_; ++pid) { + CacheVectorBatch(pid, true); + partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + } + + std::unique_ptr outStream = writeLocalFile(options_.data_file); + WriterOptions options; + options.setCompression(options_.compression_type); + options.setCompressionBlockSize(options_.compress_block_size); + options.setCompressionStrategy(CompressionStrategy_COMPRESSION); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + + void* bufferOut = nullptr; + int32_t sizeOut = 0; + for (auto pid = 0; pid < num_partitions_; ++ pid) { + ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut); + } + + std::fill(std::begin(partition_id_cnt_cache_), std::end(partition_id_cnt_cache_), 0); + ReleaseVarcharVector(); + num_row_splited_ = 0; + cached_vectorbatch_size_ = 0; outStream->close(); } @@ -867,7 +996,7 @@ int Splitter::DeleteSpilledTmpFile() { for (auto &pair : spilled_tmp_files_info_) { auto tmpDataFilePath = pair.first + ".data"; // 释放存储有各个临时文件的偏移数据内存 - options_.allocator->free(pair.second->data_, pair.second->capacity_); + options_.allocator->Free(pair.second->data_, pair.second->capacity_); if (IsFileExist(tmpDataFilePath)) { remove(tmpDataFilePath.c_str()); } @@ -887,17 +1016,7 @@ int Splitter::SpillToTmpFile() { WriteDataFileProto(); std::shared_ptr ptrTmp = CaculateSpilledTmpFilePartitionOffsets(); spilled_tmp_files_info_[options_.next_spilled_file_dir] = ptrTmp; - - auto cache_vectorBatch_num = vectorBatch_cache_.size(); - for (uint64_t i = 0; i < cache_vectorBatch_num; ++i) { - ReleaseVectorBatch(*vectorBatch_cache_[i]); - if (nullptr == vectorBatch_cache_[i]) { - throw std::runtime_error("delete nullptr error for free vectorBatch"); - } - delete vectorBatch_cache_[i]; - vectorBatch_cache_[i] = nullptr; - } - vectorBatch_cache_.clear(); + ReleaseVarcharVector(); num_row_splited_ = 0; cached_vectorbatch_size_ = 0; return 0; @@ -956,16 +1075,16 @@ std::string Splitter::NextSpilledFileDir() { } int Splitter::Stop() { - TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); - TIME_NANO_OR_RAISE(total_write_time_, MergeSpilled()); - TIME_NANO_OR_RAISE(total_write_time_, DeleteSpilledTmpFile()); - LogsDebug(" Spill For Splitter Stopped. total_spill_row_num_: %ld ", total_spill_row_num_); + if (isSpill) { + TIME_NANO_OR_RAISE(total_write_time_, MergeSpilled()); + TIME_NANO_OR_RAISE(total_write_time_, DeleteSpilledTmpFile()); + LogsDebug(" Spill For Splitter Stopped. 
total_spill_row_num_: %ld ", total_spill_row_num_); + } else { + TIME_NANO_OR_RAISE(total_write_time_, WriteSplit()); + } if (nullptr == vecBatchProto) { throw std::runtime_error("delete nullptr error for free protobuf vecBatch memory"); } delete vecBatchProto; //free protobuf vecBatch memory return 0; -} - - - +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index 3e20491ca885f34878a8198958083d02b38da043..a9d27da1ee58e52c2ae99987deb863a038c5630e 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -41,7 +41,6 @@ using namespace spark; using namespace google::protobuf::io; using namespace omniruntime::vec; using namespace omniruntime::type; -using namespace omniruntime::mem; struct SplitRowInfo { uint32_t copyedRow = 0; @@ -71,6 +70,8 @@ class Splitter { int protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream); + int32_t ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut); + int ComputeAndCountPartitionId(VectorBatch& vb); int AllocatePartitionBuffers(int32_t partition_id, int32_t new_size); @@ -89,6 +90,9 @@ class Splitter { void MergeSpilled(); + void WriteSplit(); + + bool isSpill = false; std::vector partition_id_; // records the pid of every row in the current vb std::vector partition_id_cnt_cur_; // row count per partition (for the vb currently being processed) std::vector partition_id_cnt_cache_; // row count per partition for the cached rows @@ -119,7 +123,6 @@ class Splitter { std::vector configured_dirs_; std::vector>>>> partition_cached_vectorbatch_; - std::vector vectorBatch_cache_; /* * varchar buffers: * partition_array_buffers_[partition_id][col_id][varcharBatch_id] @@ -136,6 +139,33 @@ class Splitter { std::vector partition_lengths_; private: + void ReleaseVarcharVector() + { + std::set::iterator it; + for (it = varcharVectorCache.begin(); it != varcharVectorCache.end(); it++) { + delete *it; + } + varcharVectorCache.clear(); + } + + void ReleaseVectorBatch(VectorBatch *vb) + { + int vectorCnt = vb->GetVectorCount(); + std::set vectorAddress; // vector deduplication + for (int vecIndex = 0; vecIndex < vectorCnt; vecIndex++) { + BaseVector *vector = vb->Get(vecIndex); + // only vectors outside the varchar cache can be released + if (varcharVectorCache.find(vector) == varcharVectorCache.end() && + vectorAddress.find(vector) == vectorAddress.end()) { + vectorAddress.insert(vector); + delete vector; + } + } + vectorAddress.clear(); + delete vb; + } + + std::set varcharVectorCache; bool first_vector_batch_ = false; std::vector vector_batch_col_types_; InputDataTypes input_col_types; @@ -150,7 +180,7 @@ public: std::map> spilled_tmp_files_info_; - VecBatch *vecBatchProto = new VecBatch(); //protobuf serialization object structure + spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf serialization object structure virtual int Split_Init(); diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h index 446cedc5f89988f115aedb7d9b3bc9b7c1c0a177..04d90130dea30a83651fff3526c08dc0992f9928 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h @@ -40,7 +40,7 @@ struct SplitOptions { int64_t thread_id = -1; int64_t task_attempt_id = -1; - BaseAllocator *allocator = omniruntime::mem::GetProcessRootAllocator(); + Allocator *allocator = Allocator::GetAllocator(); uint64_t 
spill_batch_row_num = 4096; // default value uint64_t spill_mem_threshold = 1024 * 1024 * 1024; // default value diff --git a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f917e22c1d51d4e15ed0fa2a861408441a9d040 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp @@ -0,0 +1,298 @@ +/** + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "jni/jni_common.h" +#include "ParquetReader.h" + +using namespace omniruntime::vec; +using namespace omniruntime::type; +using namespace arrow; +using namespace parquet::arrow; +using namespace arrow::compute; +using namespace spark::reader; + +static std::mutex mutex_; +static std::map restore_filesysptr; +static constexpr int32_t PARQUET_MAX_DECIMAL64_DIGITS = 18; +static constexpr int32_t INT128_BYTES = 16; +static constexpr int32_t INT64_BYTES = 8; +static constexpr int32_t BYTE_BITS = 8; +static constexpr int32_t LOCAL_FILE_PREFIX = 5; +static const std::string LOCAL_FILE = "file:"; +static const std::string HDFS_FILE = "hdfs:"; + +std::string spark::reader::GetFileSystemKey(std::string& path, std::string& ugi) +{ + // for local files, all files share the same key "file:" + std::string result = ugi; + + // for HDFS files, build the key from the ugi plus the ip and port + if (path.substr(0, LOCAL_FILE_PREFIX) == HDFS_FILE) { + auto mid = path.find(":", LOCAL_FILE_PREFIX); + auto end = path.find("/", mid); + std::string s1 = path.substr(LOCAL_FILE_PREFIX, mid - LOCAL_FILE_PREFIX); + std::string s2 = path.substr(mid + 1, end - (mid + 1)); + result += s1 + ":" + s2; + return result; + } + + // for local files, use ugi + "file:" as the key + if (path.substr(0, LOCAL_FILE_PREFIX) == LOCAL_FILE) { + // strip the "file://" head from the path, since Arrow cannot read it + path = path.substr(LOCAL_FILE_PREFIX); + result += "file:"; + return result; + } + + // neither local nor HDFS: use ugi + path as the key + result += path; + return result; +} + +Filesystem* spark::reader::GetFileSystemPtr(std::string& path, std::string& ugi, arrow::Status &status) +{ + auto key = GetFileSystemKey(path, ugi); + + // if the key is not found, create the filesystem ptr + auto iter = restore_filesysptr.find(key); + if (iter == restore_filesysptr.end()) { + Filesystem* fs = new Filesystem(); + auto result = fs::FileSystemFromUriOrPath(path); + status = result.status(); + if (!status.ok()) { + return nullptr; + } + fs->filesys_ptr = std::move(result).ValueUnsafe(); + 
restore_filesysptr[key] = fs; + } + + return restore_filesysptr[key]; +} + +Status ParquetReader::InitRecordReader(std::string& filePath, int64_t capacity, + const std::vector& row_group_indices, const std::vector& column_indices, std::string& ugi) +{ + arrow::MemoryPool* pool = default_memory_pool(); + + // Configure reader settings + auto reader_properties = parquet::ReaderProperties(pool); + + // Configure Arrow-specific reader settings + auto arrow_reader_properties = parquet::ArrowReaderProperties(); + arrow_reader_properties.set_batch_size(capacity); + + // Get the file from filesystem + Status result; + mutex_.lock(); + Filesystem* fs = GetFileSystemPtr(filePath, ugi, result); + mutex_.unlock(); + if (fs == nullptr || fs->filesys_ptr == nullptr) { + return Status::IOError(result); + } + ARROW_ASSIGN_OR_RAISE(auto file, fs->filesys_ptr->OpenInputFile(filePath)); + + FileReaderBuilder reader_builder; + ARROW_RETURN_NOT_OK(reader_builder.Open(file, reader_properties)); + reader_builder.memory_pool(pool); + reader_builder.properties(arrow_reader_properties); + + ARROW_ASSIGN_OR_RAISE(arrow_reader, reader_builder.Build()); + ARROW_RETURN_NOT_OK(arrow_reader->GetRecordBatchReader(row_group_indices, column_indices, &rb_reader)); + return arrow::Status::OK(); +} + +Status ParquetReader::ReadNextBatch(std::shared_ptr *batch) +{ + ARROW_RETURN_NOT_OK(rb_reader->ReadNext(batch)); + return arrow::Status::OK(); +} + +/** + * For BooleanType, copy values one by one. + */ +uint64_t CopyBooleanType(std::shared_ptr array) +{ + arrow::BooleanArray *lvb = dynamic_cast(array.get()); + auto numElements = lvb->length(); + auto originalVector = new Vector(numElements); + for (int64_t i = 0; i < numElements; i++) { + if (lvb->IsNull(i)) { + originalVector->SetNull(i); + } else { + if (lvb->Value(i)) { + originalVector->SetValue(i, true); + } else { + originalVector->SetValue(i, false); + } + } + } + return (uint64_t)originalVector; +} + +/** + * For int16/int32/int64/double type, copy values in batches and skip setNull if there is no nulls. 
+ */ +template uint64_t CopyFixedWidth(std::shared_ptr array) +{ + using T = typename NativeType::type; + PARQUET_TYPE *lvb = dynamic_cast(array.get()); + auto numElements = lvb->length(); + auto values = lvb->raw_values(); + auto originalVector = new Vector(numElements); + // Check ColumnVectorBatch has null or not firstly + if (lvb->null_count() != 0) { + for (int64_t i = 0; i < numElements; i++) { + if (lvb->IsNull(i)) { + originalVector->SetNull(i); + } + } + } + originalVector->SetValues(0, values, numElements); + return (uint64_t)originalVector; +} + +uint64_t CopyVarWidth(std::shared_ptr array) +{ + auto lvb = dynamic_cast(array.get()); + auto numElements = lvb->length(); + auto originalVector = new Vector>(numElements); + for (int64_t i = 0; i < numElements; i++) { + if (lvb->IsValid(i)) { + auto data = lvb->GetView(i); + originalVector->SetValue(i, data); + } else { + originalVector->SetNull(i); + } + } + return (uint64_t)originalVector; +} + +uint64_t CopyToOmniDecimal128Vec(std::shared_ptr array) +{ + auto lvb = dynamic_cast(array.get()); + auto numElements = lvb->length(); + auto originalVector = new Vector(numElements); + for (int64_t i = 0; i < numElements; i++) { + if (lvb->IsValid(i)) { + auto data = lvb->GetValue(i); + __int128_t val; + memcpy_s(&val, sizeof(val), data, INT128_BYTES); + omniruntime::type::Decimal128 d128(val); + originalVector->SetValue(i, d128); + } else { + originalVector->SetNull(i); + } + } + return (uint64_t)originalVector; +} + +uint64_t CopyToOmniDecimal64Vec(std::shared_ptr array) +{ + auto lvb = dynamic_cast(array.get()); + auto numElements = lvb->length(); + auto originalVector = new Vector(numElements); + for (int64_t i = 0; i < numElements; i++) { + if (lvb->IsValid(i)) { + auto data = lvb->GetValue(i); + int64_t val; + memcpy_s(&val, sizeof(val), data, INT64_BYTES); + originalVector->SetValue(i, val); + } else { + originalVector->SetNull(i); + } + } + return (uint64_t)originalVector; +} + +int spark::reader::CopyToOmniVec(std::shared_ptr vcType, int &omniTypeId, uint64_t &omniVecId, + std::shared_ptr array) +{ + switch (vcType->id()) { + case arrow::Type::BOOL: + omniTypeId = static_cast(OMNI_BOOLEAN); + omniVecId = CopyBooleanType(array); + break; + case arrow::Type::INT16: + omniTypeId = static_cast(OMNI_SHORT); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::INT32: + omniTypeId = static_cast(OMNI_INT); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::DATE32: + omniTypeId = static_cast(OMNI_DATE32); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::INT64: + omniTypeId = static_cast(OMNI_LONG); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::DATE64: + omniTypeId = static_cast(OMNI_DATE64); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::DOUBLE: + omniTypeId = static_cast(OMNI_DOUBLE); + omniVecId = CopyFixedWidth(array); + break; + case arrow::Type::STRING: + omniTypeId = static_cast(OMNI_VARCHAR); + omniVecId = CopyVarWidth(array); + break; + case arrow::Type::DECIMAL128: { + auto decimalType = static_cast(vcType.get()); + if (decimalType->precision() > PARQUET_MAX_DECIMAL64_DIGITS) { + omniTypeId = static_cast(OMNI_DECIMAL128); + omniVecId = CopyToOmniDecimal128Vec(array); + } else { + omniTypeId = static_cast(OMNI_DECIMAL64); + omniVecId = CopyToOmniDecimal64Vec(array); + } + break; + } + default: { + throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + vcType->id()); + } + } + return 1; +} + +std::pair 
spark::reader::TransferToOmniVecs(std::shared_ptr batch) +{ + int64_t num_columns = batch->num_columns(); + std::vector> fields = batch->schema()->fields(); + auto vecTypes = new int64_t[num_columns]; + auto vecs = new int64_t[num_columns]; + for (int64_t colIdx = 0; colIdx < num_columns; colIdx++) { + std::shared_ptr array = batch->column(colIdx); + // One array in current batch + std::shared_ptr data = array->data(); + int omniTypeId = 0; + uint64_t omniVecId = 0; + spark::reader::CopyToOmniVec(data->type, omniTypeId, omniVecId, array); + vecTypes[colIdx] = omniTypeId; + vecs[colIdx] = omniVecId; + } + return std::make_pair(vecTypes, vecs); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h b/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h new file mode 100644 index 0000000000000000000000000000000000000000..8fef9d495801bd51f5c04d9ddf86f241a76715d0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h @@ -0,0 +1,72 @@ +/** + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_PARQUETREADER_H +#define SPARK_THESTRAL_PLUGIN_PARQUETREADER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spark::reader { + class ParquetReader { + public: + ParquetReader() {} + + arrow::Status InitRecordReader(std::string& path, int64_t capacity, + const std::vector& row_group_indices, const std::vector& column_indices, std::string& ugi); + + arrow::Status ReadNextBatch(std::shared_ptr *batch); + + std::unique_ptr arrow_reader; + + std::shared_ptr rb_reader; + }; + + class Filesystem { + public: + Filesystem() {} + + /** + * File system holds the hdfs client, which should outlive the RecordBatchReader. 
+ */ + std::shared_ptr filesys_ptr; + }; + + std::string GetFileSystemKey(std::string& path, std::string& ugi); + + Filesystem* GetFileSystemPtr(std::string& path, std::string& ugi, arrow::Status &status); + + int CopyToOmniVec(std::shared_ptr vcType, int &omniTypeId, uint64_t &omniVecId, + std::shared_ptr array); + + std::pair TransferToOmniVecs(std::shared_ptr batch); +} +#endif // SPARK_THESTRAL_PLUGIN_PARQUETREADER_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt index 13fd8a376bec0ad3eaa917109f7f98b0f772cca7..ba1ad3a773c35a101cf728f00a19ba30b0dae607 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt @@ -29,8 +29,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-runtime-1.1.0-aarch64 - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.3.0-aarch64 securec spark_columnar_plugin) diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp index 1834345d54466d8e65f34eaea4ba2c99396440e0..3031943eeae22b591de4c4b3693eb1e1744b3ac3 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp @@ -39,6 +39,7 @@ protected: if (IsFileExist(tmpTestingDir)) { DeletePathAll(tmpTestingDir.c_str()); } + testShuffleSplitterHolder.Clear(); } // run before each case... @@ -242,7 +243,7 @@ TEST_F (ShuffleTest, Split_Short_10WRows) { 0, tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { - VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, OMNI_SHORT); + VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, ShortType()); Test_splitter_split(splitterId, vb); } Test_splitter_stop(splitterId); @@ -270,7 +271,7 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) { 0, tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { - VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, OMNI_BOOLEAN); + VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, BooleanType()); Test_splitter_split(splitterId, vb); } Test_splitter_stop(splitterId); @@ -298,7 +299,7 @@ TEST_F (ShuffleTest, Split_Long_100WRows) { 0, tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { - VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, OMNI_LONG); + VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, LongType()); Test_splitter_split(splitterId, vb); } Test_splitter_stop(splitterId); diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt index 8ca2b6d593af0afccf671d38947c175f33fab157..0f026d752ed3fbecd746a7ac8cd85e730f7076f4 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt @@ -3,7 +3,7 @@ configure_file(scan_test.h.in ${CMAKE_CURRENT_SOURCE_DIR}/scan_test.h) aux_source_directory(${CMAKE_CURRENT_LIST_DIR} SCAN_TESTS_LIST) set(SCAN_TEST_TARGET tablescantest) -add_library(${SCAN_TEST_TARGET} STATIC ${SCAN_TESTS_LIST}) +add_library(${SCAN_TEST_TARGET} STATIC ${SCAN_TESTS_LIST} parquet_scan_test.cpp) target_compile_options(${SCAN_TEST_TARGET} PUBLIC ) target_include_directories(${SCAN_TEST_TARGET} PUBLIC 
$ENV{JAVA_HOME}/include) diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7da7f0ff79da724350cc3bbc3f62fcff68b948b --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp @@ -0,0 +1,128 @@ +/** + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "scan_test.h" +#include "tablescan/ParquetReader.h" + +using namespace spark::reader; +using namespace arrow; +using namespace omniruntime::vec; + +/* + * CREATE TABLE `parquet_test` ( `c1` int, `c2` varChar(60), `c3` string, `c4` bigint, + * `c5` char(40), `c6` float, `c7` double, `c8` decimal(9,8), `c9` decimal(18,5), + * `c10` boolean, `c11` smallint, `c12` timestamp, `c13` date)stored as parquet; + * + * insert into `parquet_test` values (10, "varchar_1", "string_type_1", 10000, "char_1", + * 11.11, 1111.1111, null 131.11110, true, 11, '2021-11-30 17:00:11', '2021-12-01'); + */ +TEST(read, test_parquet_reader) +{ + std::string filename = "/resources/parquet_data_all_type"; + filename = PROJECT_PATH + filename; + const std::vector row_group_indices = {0}; + const std::vector column_indices = {0, 1, 3, 6, 7, 8, 9, 10, 12}; + + ParquetReader *reader = new ParquetReader(); + std::string ugi = "root@sample"; + auto state1 = reader->InitRecordReader(filename, 1024, row_group_indices, column_indices, ugi); + ASSERT_EQ(state1, Status::OK()); + + std::shared_ptr batch; + auto state2 = reader->ReadNextBatch(&batch); + ASSERT_EQ(state2, Status::OK()); + std::cout << "num_rows: " << batch->num_rows() << std::endl; + std::cout << "num_columns: " << batch->num_columns() << std::endl; + std::cout << "Print: " << batch->ToString() << std::endl; + auto pair = TransferToOmniVecs(batch); + + BaseVector *intVector = reinterpret_cast(pair.second[0]); + auto int_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(intVector)); + ASSERT_EQ(*int_result, 10); + + auto varCharVector = reinterpret_cast> *>(pair.second[1]); + std::string str_expected = "varchar_1"; + ASSERT_TRUE(str_expected == varCharVector->GetValue(0)); + + BaseVector *longVector = reinterpret_cast(pair.second[2]); + auto long_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(longVector)); + ASSERT_EQ(*long_result, 10000); + + BaseVector *doubleVector = reinterpret_cast(pair.second[3]); + auto double_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(doubleVector)); + ASSERT_EQ(*double_result, 1111.1111); + + BaseVector *nullVector = 
reinterpret_cast(pair.second[4]); + ASSERT_TRUE(nullVector->IsNull(0)); + + BaseVector *decimal64Vector = reinterpret_cast(pair.second[5]); + auto decimal64_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(decimal64Vector)); + ASSERT_EQ(*decimal64_result, 13111110); + + BaseVector *booleanVector = reinterpret_cast(pair.second[6]); + auto boolean_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(booleanVector)); + ASSERT_EQ(*boolean_result, true); + + BaseVector *smallintVector = reinterpret_cast(pair.second[7]); + auto smallint_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(smallintVector)); + ASSERT_EQ(*smallint_result, 11); + + BaseVector *dateVector = reinterpret_cast(pair.second[8]); + auto date_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(dateVector)); + omniruntime::type::Date32 date32(*date_result); + char chars[11]; + date32.ToString(chars, 11); + std::string date_expected(chars); + ASSERT_TRUE(date_expected == "2021-12-01"); + + delete reader; + delete intVector; + delete varCharVector; + delete longVector; + delete doubleVector; + delete nullVector; + delete decimal64Vector; + delete booleanVector; + delete smallintVector; + delete dateVector; +} + +TEST(read, test_decimal128_copy) +{ + auto decimal_type = arrow::decimal(20, 1); + arrow::Decimal128Builder builder(decimal_type); + arrow::Decimal128 value(20230420); + auto s1 = builder.Append(value); + std::shared_ptr array; + auto s2 = builder.Finish(&array); + + int omniTypeId = 0; + uint64_t omniVecId = 0; + spark::reader::CopyToOmniVec(decimal_type, omniTypeId, omniVecId, array); + + BaseVector *decimal128Vector = reinterpret_cast(omniVecId); + auto decimal128_result = + static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(decimal128Vector)); + ASSERT_TRUE((*decimal128_result).ToString() == "20230420"); + + delete decimal128Vector; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/parquet_data_all_type b/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/parquet_data_all_type new file mode 100644 index 0000000000000000000000000000000000000000..3de6f3c8954f05f496f6211a813034462ae384a6 Binary files /dev/null and b/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/parquet_data_all_type differ diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp index f8a6a6b7f2776f212d7ba6b1c9ee8d9260509116..2ed604e50420c402e9184c0a4011f66d69c00158 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp @@ -17,15 +17,13 @@ * limitations under the License. 
*/ -#include "gtest/gtest.h" -#include -#include -#include "../../src/jni/OrcColumnarBatchJniReader.h" +#include +#include +#include +#include "jni/OrcColumnarBatchJniReader.h" #include "scan_test.h" -#include "orc/sargs/SearchArgument.hh" static std::string filename = "/resources/orc_data_all_type"; -static orc::ColumnVectorBatch *batchPtr; static orc::StructVectorBatch *root; /* @@ -53,17 +51,24 @@ protected: orc::ReaderOptions readerOpts; orc::RowReaderOptions rowReaderOptions; std::unique_ptr reader = orc::createReader(orc::readFile(PROJECT_PATH + filename), readerOpts); - std::unique_ptr rowReader = reader->createRowReader(); + rowReader = reader->createRowReader().release(); std::unique_ptr batch = rowReader->createRowBatch(4096); rowReader->next(*batch); - batchPtr = batch.release(); - root = static_cast(batchPtr); + types = &(rowReader->getSelectedType()); + root = static_cast(batch.release()); } // run after each case... virtual void TearDown() override { - delete batchPtr; + delete root; + root = nullptr; + types = nullptr; + delete rowReader; + rowReader = nullptr; } + + const orc::Type *types; + orc::RowReader *rowReader; }; TEST_F(ScanTest, test_literal_get_long) @@ -71,11 +76,11 @@ TEST_F(ScanTest, test_literal_get_long) orc::Literal tmpLit(0L); // test get long - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "655361"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "655361"); ASSERT_EQ(tmpLit.getLong(), 655361); - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "-655361"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "-655361"); ASSERT_EQ(tmpLit.getLong(), -655361); - getLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "0"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::LONG), "0"); ASSERT_EQ(tmpLit.getLong(), 0); } @@ -84,11 +89,11 @@ TEST_F(ScanTest, test_literal_get_float) orc::Literal tmpLit(0L); // test get float - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "12345.6789"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "12345.6789"); ASSERT_EQ(tmpLit.getFloat(), 12345.6789); - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "-12345.6789"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "-12345.6789"); ASSERT_EQ(tmpLit.getFloat(), -12345.6789); - getLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "0"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::FLOAT), "0"); ASSERT_EQ(tmpLit.getFloat(), 0); } @@ -97,9 +102,9 @@ TEST_F(ScanTest, test_literal_get_string) orc::Literal tmpLit(0L); // test get string - getLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), "testStringForLit"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), "testStringForLit"); ASSERT_EQ(tmpLit.getString(), "testStringForLit"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), ""); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::STRING), ""); ASSERT_EQ(tmpLit.getString(), ""); } @@ -108,7 +113,7 @@ TEST_F(ScanTest, test_literal_get_date) orc::Literal tmpLit(0L); // test get date - getLiteral(tmpLit, (int)(orc::PredicateDataType::DATE), "987654321"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DATE), "987654321"); ASSERT_EQ(tmpLit.getDate(), 987654321); } @@ -117,15 +122,15 @@ TEST_F(ScanTest, test_literal_get_decimal) orc::Literal tmpLit(0L); // test get decimal - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "199999999999998.998000 22 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "199999999999998.998000 22 6"); 
ASSERT_EQ(tmpLit.getDecimal().toString(), "199999999999998.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "10.998000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "10.998000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "10.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-10.998000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-10.998000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "-10.998000"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "9999.999999 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "9999.999999 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "9999.999999"); - getLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-0.000000 10 6"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::DECIMAL), "-0.000000 10 6"); ASSERT_EQ(tmpLit.getDecimal().toString(), "0.000000"); } @@ -134,17 +139,17 @@ TEST_F(ScanTest, test_literal_get_bool) orc::Literal tmpLit(0L); // test get bool - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "true"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "true"); ASSERT_EQ(tmpLit.getBool(), true); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "True"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "True"); ASSERT_EQ(tmpLit.getBool(), true); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "false"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "false"); ASSERT_EQ(tmpLit.getBool(), false); - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "False"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "False"); ASSERT_EQ(tmpLit.getBool(), false); std::string tmpStr = ""; try { - getLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "exception"); + GetLiteral(tmpLit, (int)(orc::PredicateDataType::BOOLEAN), "exception"); } catch (std::exception &e) { tmpStr = e.what(); } @@ -156,9 +161,9 @@ TEST_F(ScanTest, test_copy_intVec) int omniType = 0; uint64_t omniVecId = 0; // int type - copyToOmniVec(orc::TypeKind::INT, omniType, omniVecId, root->fields[0]); + CopyToOmniVec(types->getSubtype(0), omniType, omniVecId, root->fields[0]); ASSERT_EQ(omniType, omniruntime::type::OMNI_INT); - omniruntime::vec::IntVector *olbInt = (omniruntime::vec::IntVector *)(omniVecId); + auto *olbInt = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbInt->GetValue(0), 10); delete olbInt; } @@ -168,12 +173,11 @@ TEST_F(ScanTest, test_copy_varCharVec) int omniType = 0; uint64_t omniVecId = 0; // varchar type - copyToOmniVec(orc::TypeKind::VARCHAR, omniType, omniVecId, root->fields[1], 60); + CopyToOmniVec(types->getSubtype(1), omniType, omniVecId, root->fields[1]); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); - uint8_t *actualChar = nullptr; - omniruntime::vec::VarcharVector *olbVc = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbVc->GetValue(0, &actualChar); - std::string actualStr(reinterpret_cast(actualChar), 0, len); + auto *olbVc = (omniruntime::vec::Vector> *)( + omniVecId); + std::string_view actualStr = olbVc->GetValue(0); ASSERT_EQ(actualStr, "varchar_1"); delete olbVc; } @@ -182,14 +186,13 @@ TEST_F(ScanTest, test_copy_stringVec) { int omniType = 0; uint64_t omniVecId = 0; - uint8_t *actualChar = nullptr; // string type - copyToOmniVec(orc::TypeKind::STRING, omniType, omniVecId, root->fields[2]); + CopyToOmniVec(types->getSubtype(2), omniType, omniVecId, root->fields[2]); ASSERT_EQ(omniType, 
omniruntime::type::OMNI_VARCHAR); - omniruntime::vec::VarcharVector *olbStr = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbStr->GetValue(0, &actualChar); - std::string actualStr2(reinterpret_cast(actualChar), 0, len); - ASSERT_EQ(actualStr2, "string_type_1"); + auto *olbStr = (omniruntime::vec::Vector> *)( + omniVecId); + std::string_view actualStr = olbStr->GetValue(0); + ASSERT_EQ(actualStr, "string_type_1"); delete olbStr; } @@ -198,9 +201,9 @@ TEST_F(ScanTest, test_copy_longVec) int omniType = 0; uint64_t omniVecId = 0; // bigint type - copyToOmniVec(orc::TypeKind::LONG, omniType, omniVecId, root->fields[3]); + CopyToOmniVec(types->getSubtype(3), omniType, omniVecId, root->fields[3]); ASSERT_EQ(omniType, omniruntime::type::OMNI_LONG); - omniruntime::vec::LongVector *olbLong = (omniruntime::vec::LongVector *)(omniVecId); + auto *olbLong = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbLong->GetValue(0), 10000); delete olbLong; } @@ -209,15 +212,14 @@ TEST_F(ScanTest, test_copy_charVec) { int omniType = 0; uint64_t omniVecId = 0; - uint8_t *actualChar = nullptr; // char type - copyToOmniVec(orc::TypeKind::CHAR, omniType, omniVecId, root->fields[4], 40); + CopyToOmniVec(types->getSubtype(4), omniType, omniVecId, root->fields[4]); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); - omniruntime::vec::VarcharVector *olbChar40 = (omniruntime::vec::VarcharVector *)(omniVecId); - int len = olbChar40->GetValue(0, &actualChar); - std::string actualStr3(reinterpret_cast(actualChar), 0, len); - ASSERT_EQ(actualStr3, "char_1"); - delete olbChar40; + auto *olbChar = (omniruntime::vec::Vector> *)( + omniVecId); + std::string_view actualStr = olbChar->GetValue(0); + ASSERT_EQ(actualStr, "char_1"); + delete olbChar; } TEST_F(ScanTest, test_copy_doubleVec) @@ -225,9 +227,9 @@ TEST_F(ScanTest, test_copy_doubleVec) int omniType = 0; uint64_t omniVecId = 0; // double type - copyToOmniVec(orc::TypeKind::DOUBLE, omniType, omniVecId, root->fields[6]); + CopyToOmniVec(types->getSubtype(6), omniType, omniVecId, root->fields[6]); ASSERT_EQ(omniType, omniruntime::type::OMNI_DOUBLE); - omniruntime::vec::DoubleVector *olbDouble = (omniruntime::vec::DoubleVector *)(omniVecId); + auto *olbDouble = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbDouble->GetValue(0), 1111.1111); delete olbDouble; } @@ -237,9 +239,9 @@ TEST_F(ScanTest, test_copy_booleanVec) int omniType = 0; uint64_t omniVecId = 0; // boolean type - copyToOmniVec(orc::TypeKind::BOOLEAN, omniType, omniVecId, root->fields[9]); + CopyToOmniVec(types->getSubtype(9), omniType, omniVecId, root->fields[9]); ASSERT_EQ(omniType, omniruntime::type::OMNI_BOOLEAN); - omniruntime::vec::BooleanVector *olbBoolean = (omniruntime::vec::BooleanVector *)(omniVecId); + auto *olbBoolean = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbBoolean->GetValue(0), true); delete olbBoolean; } @@ -249,9 +251,9 @@ TEST_F(ScanTest, test_copy_shortVec) int omniType = 0; uint64_t omniVecId = 0; // short type - copyToOmniVec(orc::TypeKind::SHORT, omniType, omniVecId, root->fields[10]); + CopyToOmniVec(types->getSubtype(10), omniType, omniVecId, root->fields[10]); ASSERT_EQ(omniType, omniruntime::type::OMNI_SHORT); - omniruntime::vec::ShortVector *olbShort = (omniruntime::vec::ShortVector *)(omniVecId); + auto *olbShort = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbShort->GetValue(0), 11); delete olbShort; } @@ -265,24 +267,26 @@ TEST_F(ScanTest, test_build_leafs) orc::Literal lit(100L); // test EQUALS - 
buildLeaves(PredicateOperatorType::EQUALS, litList, lit, "leaf-0", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::EQUALS, litList, lit, "leaf-0", orc::PredicateDataType::LONG, *builder); // test LESS_THAN - buildLeaves(PredicateOperatorType::LESS_THAN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::LESS_THAN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); // test LESS_THAN_EQUALS - buildLeaves(PredicateOperatorType::LESS_THAN_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::LESS_THAN_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, + *builder); // test NULL_SAFE_EQUALS - buildLeaves(PredicateOperatorType::NULL_SAFE_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::NULL_SAFE_EQUALS, litList, lit, "leaf-1", orc::PredicateDataType::LONG, + *builder); // test IS_NULL - buildLeaves(PredicateOperatorType::IS_NULL, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::IS_NULL, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); // test BETWEEN std::string tmpStr = ""; try { - buildLeaves(PredicateOperatorType::BETWEEN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); + BuildLeaves(PredicateOperatorType::BETWEEN, litList, lit, "leaf-1", orc::PredicateDataType::LONG, *builder); } catch (std::exception &e) { tmpStr = e.what(); } diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp index 586f4bbdb95721b22422d715f645eb502dc1a894..9010cf1504f70eb66fde51c72f60a39318320214 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp @@ -21,199 +21,35 @@ using namespace omniruntime::vec; -void ToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) -{ - for (int i = 0; i < dataTypeCount; ++i) { - if (dataTypeIds[i] == OMNI_VARCHAR) { - dataTypes.push_back(std::make_shared(50)); - continue; - } else if (dataTypeIds[i] == OMNI_CHAR) { - dataTypes.push_back(std::make_shared(50)); - continue; - } - dataTypes.push_back(std::make_shared(dataTypeIds[i])); - } -} - -VectorBatch* CreateInputData(const int32_t numRows, - const int32_t numCols, - int32_t* inputTypeIds, - int64_t* allData) -{ - auto *vecBatch = new VectorBatch(numCols, numRows); - vector inputTypes; - ToVectorTypes(inputTypeIds, numCols, inputTypes); - vecBatch->NewVectors(omniruntime::vec::GetProcessGlobalVecAllocator(), inputTypes); - for (int i = 0; i < numCols; ++i) { - switch (inputTypeIds[i]) { - case OMNI_BOOLEAN: - ((BooleanVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_INT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_LONG: - ((LongVector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - break; - case OMNI_DOUBLE: - ((DoubleVector *)vecBatch->GetVector(i))->SetValues(0, (double *)allData[i], numRows); - break; - case OMNI_SHORT: - ((ShortVector *)vecBatch->GetVector(i))->SetValues(0, (int16_t *)allData[i], numRows); - break; - case OMNI_VARCHAR: - case OMNI_CHAR: { - for (int j = 0; j < numRows; ++j) { - int64_t addr = (reinterpret_cast(allData[i]))[j]; - std::string s 
(reinterpret_cast(addr)); - ((VarcharVector *)vecBatch->GetVector(i))->SetValue(j, (uint8_t *)(s.c_str()), s.length()); - } - break; - } - case OMNI_DECIMAL128: - ((Decimal128Vector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *) allData[i], numRows); - break; - default:{ - LogError("No such data type %d", inputTypeIds[i]); - } - } - } - return vecBatch; -} - -VarcharVector *CreateVarcharVector(VarcharDataType type, std::string *values, int32_t length) +VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) { - VectorAllocator * vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - uint32_t width = type.GetWidth(); - VarcharVector *vector = std::make_unique(vecAllocator, length * width, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, reinterpret_cast(values[i].c_str()), values[i].length()); + int32_t typesCount = types.GetSize(); + auto *vectorBatch = new VectorBatch(rowCount); + va_list args; + va_start(args, rowCount); + for (int32_t i = 0; i < typesCount; i++) { + DataTypePtr type = types.GetType(i); + vectorBatch->Append(CreateVector(*type, rowCount, args)); } - return vector; + va_end(args); + return vectorBatch; } -Decimal128Vector *CreateDecimal128Vector(Decimal128 *values, int32_t length) +BaseVector *CreateVector(DataType &dataType, int32_t rowCount, va_list &args) { - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - Decimal128Vector *vector = std::make_unique(vecAllocator, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, values[i]); - } - return vector; + return DYNAMIC_TYPE_DISPATCH(CreateFlatVector, dataType.GetId(), rowCount, args); } -Vector *CreateVector(DataType &vecType, int32_t rowCount, va_list &args) -{ - switch (vecType.GetId()) { - case OMNI_INT: - case OMNI_DATE32: - return CreateVector(va_arg(args, int32_t *), rowCount); - case OMNI_LONG: - case OMNI_DECIMAL64: - return CreateVector(va_arg(args, int64_t *), rowCount); - case OMNI_DOUBLE: - return CreateVector(va_arg(args, double *), rowCount); - case OMNI_BOOLEAN: - return CreateVector(va_arg(args, bool *), rowCount); - case OMNI_VARCHAR: - case OMNI_CHAR: - return CreateVarcharVector(static_cast(vecType), va_arg(args, std::string *), rowCount); - case OMNI_DECIMAL128: - return CreateDecimal128Vector(va_arg(args, Decimal128 *), rowCount); - default: - std::cerr << "Unsupported type : " << vecType.GetId() << std::endl; - return nullptr; - } -} - -DictionaryVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...) +BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ...) 
{ va_list args; va_start(args, idsCount); - Vector *dictionary = CreateVector(dataType, rowCount, args); + BaseVector *dictionary = CreateVector(dataType, rowCount, args); va_end(args); - auto vec = new DictionaryVector(dictionary, ids, idsCount); + auto dictionaryVector = DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount); delete dictionary; - return vec; -} - -Vector *buildVector(const DataType &aggType, int32_t rowNumber) -{ - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - switch (aggType.GetId()) { - case OMNI_NONE: { - LongVector *col = new LongVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValueNull(j); - } - return col; - } - case OMNI_INT: - case OMNI_DATE32: { - IntVector *col = new IntVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_LONG: - case OMNI_DECIMAL64: { - LongVector *col = new LongVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_DOUBLE: { - DoubleVector *col = new DoubleVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_BOOLEAN: { - BooleanVector *col = new BooleanVector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, 1); - } - return col; - } - case OMNI_DECIMAL128: { - Decimal128Vector *col = new Decimal128Vector(vecAllocator, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - col->SetValue(j, Decimal128(0, 1)); - } - return col; - } - case OMNI_VARCHAR: - case OMNI_CHAR: { - VarcharDataType charType = (VarcharDataType &)aggType; - VarcharVector *col = new VarcharVector(vecAllocator, charType.GetWidth() * rowNumber, rowNumber); - for (int32_t j = 0; j < rowNumber; ++j) { - std::string str = std::to_string(j); - col->SetValue(j, reinterpret_cast(str.c_str()), str.size()); - } - return col; - } - default: { - LogError("No such %d type support", aggType.GetId()); - return nullptr; - } - } -} - -VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) -{ - int32_t typesCount = types.GetSize(); - auto *vectorBatch = new VectorBatch(typesCount, rowCount); - va_list args; - va_start(args, rowCount); - for (int32_t i = 0; i < typesCount; i++) { - DataTypePtr type = types.GetType(i); - vectorBatch->SetVector(i, CreateVector(*type, rowCount, args)); - } - va_end(args); - return vectorBatch; + return dictionaryVector; } /** @@ -225,24 +61,16 @@ VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) 
*/ VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputString) { // gen vectorBatch - const int32_t numCols = 2; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), VarcharType() })); const int32_t numRows = 1; auto* col1 = new int32_t[numRows]; col1[0] = pid; - auto* col2 = new int64_t[numRows]; - std::string* strTmp = new std::string(inputString); - col2[0] = (int64_t)(strTmp->c_str()); + auto* col2 = new std::string[numRows]; + col2[0] = std::move(inputString); - int64_t allData[numCols] = {reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col1, col2); delete[] col1; delete[] col2; - delete strTmp; return in; } @@ -255,224 +83,144 @@ VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputSt */ VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), IntType(), LongType(), DoubleType(), VarcharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; auto* col1 = new int32_t[numRows]; auto* col2 = new int64_t[numRows]; auto* col3 = new double[numRows]; - auto* col4 = new int64_t[numRows]; - string startStr = "_START_"; - string endStr = "_END_"; + auto* col4 = new std::string[numRows]; + std::string startStr = "_START_"; + std::string endStr = "_END_"; std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; + col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; col2[i] = i + 1; col3[i] = i + 1; - std::string* strTmp = new std::string(startStr + to_string(i + 1) + endStr); - string_cache_test_.push_back(strTmp); - col4[i] = (int64_t)((*strTmp).c_str()); + std::string strTmp = std::string(startStr + to_string(i + 1) + endStr); + col4[i] = std::move(strTmp); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } -VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, int32_t fixColType) { +VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, DataTypePtr fixColType) { int partitionNum = parNum; - const int32_t numCols = 2; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = fixColType; + DataTypes inputTypes(std::vector({ IntType(), std::move(fixColType) })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; auto* col1 = new int64_t[numRows]; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; + col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - 
reinterpret_cast(col1)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); delete[] col0; delete[] col1; return in; } VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) { - const int32_t numCols = 3; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_INT; + DataTypes inputTypes(std::vector({ IntType(), VarcharType(), IntType() })); const int32_t numRows = 1; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; auto* col2 = new int32_t[numRows]; col0[0] = pid; - std::string* strTmp = new std::string(strVar); - col1[0] = (int64_t)(strTmp->c_str()); + col1[0] = std::move(strVar); col2[0] = intVar; - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); delete[] col0; delete[] col1; delete[] col2; - delete strTmp; return in; } VectorBatch* CreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_VARCHAR; - inputTypes[3] = OMNI_VARCHAR; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes( + std::vector({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; - auto* col2 = new int64_t[numRows]; - auto* col3 = new int64_t[numRows]; - auto* col4 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new std::string[numRows]; + auto* col3 = new std::string[numRows]; + auto* col4 = new std::string[numRows]; - std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; - std::string* strTmp1 = new std::string("Col1_START_" + to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - std::string* strTmp2 = new std::string("Col2_START_" + to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - std::string* strTmp3 = new std::string("Col3_START_" + to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - std::string* strTmp4 = new std::string("Col4_START_" + to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_START_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_START_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_START_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_START_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = 
CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } VectorBatch* CreateVectorBatch_4charCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 5; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_CHAR; - inputTypes[2] = OMNI_CHAR; - inputTypes[3] = OMNI_CHAR; - inputTypes[4] = OMNI_CHAR; + DataTypes inputTypes(std::vector({ IntType(), CharType(), CharType(), CharType(), CharType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; - auto* col1 = new int64_t[numRows]; - auto* col2 = new int64_t[numRows]; - auto* col3 = new int64_t[numRows]; - auto* col4 = new int64_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new std::string[numRows]; + auto* col3 = new std::string[numRows]; + auto* col4 = new std::string[numRows]; std::vector string_cache_test_; for (int i = 0; i < numRows; i++) { - col0[i] = (i+1) % partitionNum; - std::string* strTmp1 = new std::string("Col1_CHAR_" + to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - std::string* strTmp2 = new std::string("Col2_CHAR_" + to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - std::string* strTmp3 = new std::string("Col3_CHAR_" + to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - std::string* strTmp4 = new std::string("Col4_CHAR_" + to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_CHAR_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_CHAR_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_CHAR_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_CHAR_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (uint p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // release memory - } return in; } VectorBatch* CreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - // gen vectorBatch - const int32_t numCols = 6; - int32_t* inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_BOOLEAN; - inputTypes[2] = OMNI_SHORT; - inputTypes[3] = OMNI_INT; - inputTypes[4] = OMNI_LONG; - inputTypes[5] = OMNI_DOUBLE; + DataTypes inputTypes( + std::vector({ IntType(), BooleanType(), ShortType(), IntType(), LongType(), DoubleType() })); const int32_t numRows = rowNum; auto* col0 = new int32_t[numRows]; @@ 
-490,14 +238,7 @@ VectorBatch* CreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { col5[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4), - reinterpret_cast(col5)}; - VectorBatch* in = CreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4, col5); delete[] col0; delete[] col1; delete[] col2; @@ -512,71 +253,85 @@ VectorBatch* CreateVectorBatch_2dictionaryCols_withPid(int partitionNum) { // construct input data const int32_t dataSize = 6; // prepare data - int32_t data0[dataSize] = {111, 112, 113, 114, 115, 116}; - int64_t data1[dataSize] = {221, 222, 223, 224, 225, 226}; - void *datas[2] = {data0, data1}; - DataTypes sourceTypes(std::vector({ std::make_unique(), std::make_unique()})); - int32_t ids[] = {0, 1, 2, 3, 4, 5}; - VectorBatch *vectorBatch = new VectorBatch(3, dataSize); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - IntVector *intVectorTmp = new IntVector(allocator, 6); - for (int i = 0; i < intVectorTmp->GetSize(); i++) { - intVectorTmp->SetValue(i, (i+1) % partitionNum); - } - for (int32_t i = 0; i < 3; i ++) { - if (i == 0) { - vectorBatch->SetVector(i, intVectorTmp); - } else { - omniruntime::vec::DataType dataType = *(sourceTypes.Get()[i - 1]); - vectorBatch->SetVector(i, CreateDictionaryVector(dataType, dataSize, ids, dataSize, datas[i - 1])); - } + auto *col0 = new int32_t[dataSize]; + for (int32_t i = 0; i< dataSize; i++) { + col0[i] = (i + 1) % partitionNum; } + int32_t col1[dataSize] = {111, 112, 113, 114, 115, 116}; + int64_t col2[dataSize] = {221, 222, 223, 224, 225, 226}; + void *datas[2] = {col1, col2}; + DataTypes sourceTypes(std::vector({ IntType(), LongType() })); + int32_t ids[] = {0, 1, 2, 3, 4, 5}; + + VectorBatch *vectorBatch = new VectorBatch(dataSize); + auto Vec0 = CreateVector(dataSize, col0); + vectorBatch->Append(Vec0); + auto dicVec0 = CreateDictionaryVector(*sourceTypes.GetType(0), dataSize, ids, dataSize, datas[0]); + auto dicVec1 = CreateDictionaryVector(*sourceTypes.GetType(1), dataSize, ids, dataSize, datas[1]); + vectorBatch->Append(dicVec0); + vectorBatch->Append(dicVec1); + + delete[] col0; return vectorBatch; } VectorBatch* CreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum) { - auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch* CreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); - VectorAllocator *allocator = 
VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch* CreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = buildVector(Decimal64DataType(7, 2), rowNum); - auto decimal128InputVec = buildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + auto *col2 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; + col2[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(3); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - vecBatch->SetVector(2, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); + delete[] col0; + delete[] col1; + delete[] col2; + return in; } VectorBatch* CreateVectorBatch_someNullRow_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 6; bool data0[numRows] = {true, false, true, false, true, false}; int16_t data1[numRows] = {0, 1, 2, 3, 4, 6}; int32_t data2[numRows] = {0, 1, 2, 0, 1, 2}; @@ -584,50 +339,32 @@ VectorBatch* CreateVectorBatch_someNullRow_vectorBatch() { double data4[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data5[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = CreateVector(data0, numRows); - auto vec1 = CreateVector(data1, numRows); - auto vec2 = CreateVector(data2, numRows); - auto vec3 = CreateVector(data3, numRows); - auto vec4 = CreateVector(data4, numRows); - auto vec5 = CreateVarcharVector(VarcharDataType(5), data5, numRows); - for (int i = 0; i < numRows; i = i + 2) { - vec0->SetValueNull(i); - vec1->SetValueNull(i); - vec2->SetValueNull(i); - vec3->SetValueNull(i); - vec4->SetValueNull(i); - vec5->SetValueNull(i); + DataTypes inputTypes( + std::vector({ BooleanType(), ShortType(), IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data0, data1, data2, data3, data4, data5); + for (int32_t i = 0; i < numCols; i++) { + for (int32_t j = 0; j < numRows; j = j + 2) { + vecBatch->Get(i)->SetNull(j); + } } - VectorBatch *vecBatch = new VectorBatch(6); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); - vecBatch->SetVector(4, vec4); - vecBatch->SetVector(5, vec5); return vecBatch; } VectorBatch* 
CreateVectorBatch_someNullCol_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 4; int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = CreateVector(data1, numRows); - auto vec1 = CreateVector(data2, numRows); - auto vec2 = CreateVector(data3, numRows); - auto vec3 = CreateVarcharVector(VarcharDataType(5), data4, numRows); - for (int i = 0; i < numRows; i = i + 1) { - vec1->SetValueNull(i); - vec3->SetValueNull(i); + DataTypes inputTypes(std::vector({ IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data1, data2, data3, data4); + for (int32_t i = 0; i < numCols; i = i + 2) { + for (int32_t j = 0; j < numRows; j++) { + vecBatch->Get(i)->SetNull(j); + } } - VectorBatch *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } @@ -687,17 +424,17 @@ long Test_splitter_nativeMake(std::string partitioning_name, splitOptions.compression_type = compression_type_result; splitOptions.data_file = data_file_jstr; auto splitter = Splitter::Make(partitioning_name, inputDataTypes, numCols, num_partitions, std::move(splitOptions)); - return shuffle_splitter_holder_.Insert(std::shared_ptr(splitter)); + return testShuffleSplitterHolder.Insert(std::shared_ptr(splitter)); } void Test_splitter_split(long splitter_id, VectorBatch* vb) { - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); - //初始化split各全局变量 + auto splitter = testShuffleSplitterHolder.Lookup(splitter_id); + // Initialize split global variables splitter->Split(*vb); } void Test_splitter_stop(long splitter_id) { - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + auto splitter = testShuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); throw std::runtime_error("Test no splitter."); @@ -706,12 +443,12 @@ void Test_splitter_stop(long splitter_id) { } void Test_splitter_close(long splitter_id) { - auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + auto splitter = testShuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); throw std::runtime_error("Test no splitter."); } - shuffle_splitter_holder_.Erase(splitter_id); + testShuffleSplitterHolder.Erase(splitter_id); } void GetFilePath(const char *path, const char *filename, char *filepath) { diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h index 496a4cc6fc6d0a8834a95db72ccccb5376fe02b6..b7380254a687ed6f3eaf8234df944feac9087404 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h @@ -25,22 +25,70 @@ #include #include #include -#include "../../src/shuffle/splitter.h" -#include "../../src/jni/concurrent_map.h" +#include "shuffle/splitter.h" +#include "jni/concurrent_map.h" -static ConcurrentMap> shuffle_splitter_holder_; +static ConcurrentMap> testShuffleSplitterHolder; static std::string s_shuffle_tests_dir = "/tmp/shuffleTests"; -VectorBatch* CreateInputData(const int32_t numRows, const int32_t numCols, int32_t* inputTypeIds, 
int64_t* allData); +VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...); -Vector *buildVector(const DataType &aggType, int32_t rowNumber); +BaseVector *CreateVector(DataType &dataType, int32_t rowCount, va_list &args); + +template BaseVector *CreateVector(int32_t length, T *values) +{ + Vector *vector = new Vector(length); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, values[i]); + } + return vector; +} + +template +BaseVector *CreateFlatVector(int32_t length, va_list &args) +{ + using namespace omniruntime::type; + using T = typename NativeType::type; + using VarcharVector = Vector>; + if constexpr (std::is_same_v) { + VarcharVector *vector = new VarcharVector(length); + std::string *str = va_arg(args, std::string *); + for (int32_t i = 0; i < length; i++) { + std::string_view value(str[i].data(), str[i].length()); + vector->SetValue(i, value); + } + return vector; + } else { + Vector *vector = new Vector(length); + T *value = va_arg(args, T *); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, value[i]); + } + return vector; + } +} + +BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ...); + +template +BaseVector *CreateDictionary(BaseVector *vector, int32_t *ids, int32_t size) +{ + using T = typename NativeType::type; + if constexpr (std::is_same_v) { + return VectorHelper::CreateStringDictionary(ids, size, + reinterpret_cast> *>(vector)); + } else { + return VectorHelper::CreateDictionary(ids, size, reinterpret_cast *>(vector)); + } +} VectorBatch* CreateVectorBatch_1row_varchar_withPid(int pid, std::string inputChar); VectorBatch* CreateVectorBatch_4col_withPid(int parNum, int rowNum); -VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, int32_t fixColType); +VectorBatch* CreateVectorBatch_1FixCol_withPid(int parNum, int rowNum, DataTypePtr fixColType); VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar); @@ -79,14 +127,6 @@ void Test_splitter_stop(long splitter_id); void Test_splitter_close(long splitter_id); -template T *CreateVector(V *values, int32_t length) -{ - VectorAllocator *vecAllocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto vector = new T(vecAllocator, length); - vector->SetValues(0, values, length); - return vector; -} - void GetFilePath(const char *path, const char *filename, char *filepath); void DeletePathAll(const char* path); diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index caafa313fbd2cb88b124e370f2d73460199b7051..44c8236b619810b1f8c52f7d6c4d117052ad80a0 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.1.1-1.1.0 + 3.1.1-1.3.0 ../pom.xml @@ -29,33 +29,6 @@ - - - commons-beanutils - commons-beanutils - 1.9.4 - - - org.checkerframework - checker-qual - 3.8.0 - - - com.google.errorprone - error_prone_annotations - 2.4.0 - - - com.google.guava - guava - 31.0.1-jre - - - xerces - xercesImpl - 2.12.2 - - org.apache.spark spark-sql_${scala.binary.version} @@ -103,20 +76,20 @@ spark-core_${scala.binary.version} test-jar test - 3.1.1 + ${spark.version} org.apache.spark spark-catalyst_${scala.binary.version} test-jar test - 3.1.1 + ${spark.version} org.apache.spark spark-sql_${scala.binary.version} test-jar - 3.1.1 + ${spark.version} test @@ -127,7 +100,7 @@ org.apache.spark 
spark-hive_${scala.binary.version} - 3.1.1 + ${spark.version} provided diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java index fa5cb11b248cc70f3253d620649c96b6a4f6a0ac..d80a236533c6b2b3305b2f443b759877239d6089 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java @@ -19,7 +19,6 @@ package com.huawei.boostkit.spark.jni; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.type.Decimal128DataType; import nova.hetu.omniruntime.vector.*; import org.apache.spark.sql.catalyst.util.RebaseDateTime; @@ -117,7 +116,7 @@ public class OrcColumnarBatchJniReader { lst.add(spiltValues[0] + "." + scalePadZeroStr + " " + decimalP + " " + decimalS); } } else if (pl.getType() == PredicateLeaf.Type.DATE) { - lst.add(((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); + lst.add(((int)Math.ceil(((Date)ob).getTime()* 1.0/3600/24/1000)) + ""); } else { lst.add(ob.toString()); } @@ -273,7 +272,7 @@ public class OrcColumnarBatchJniReader { break; } case OMNI_DECIMAL128: { - vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId], Decimal128DataType.DECIMAL128); + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); break; } default: { diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java new file mode 100644 index 0000000000000000000000000000000000000000..3a5cffb09c4792e2731dcd9ab8b377417f888b4c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2021-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark.jni; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.*; + +import org.apache.spark.sql.catalyst.util.RebaseDateTime; + +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class ParquetColumnarBatchJniReader { + private static final Logger LOGGER = LoggerFactory.getLogger(ParquetColumnarBatchJniReader.class); + + public long parquetReader; + + public ParquetColumnarBatchJniReader() { + NativeLoader.getInstance(); + } + + public long initializeReaderJava(String path, int capacity, + List rowgroupIndices, List columnIndices, String ugi) { + JSONObject job = new JSONObject(); + job.put("filePath", path); + job.put("capacity", capacity); + job.put("rowGroupIndices", rowgroupIndices.stream().mapToInt(Integer::intValue).toArray()); + job.put("columnIndices", columnIndices.stream().mapToInt(Integer::intValue).toArray()); + job.put("ugi", ugi); + parquetReader = initializeReader(job); + return parquetReader; + } + + public int next(Vec[] vecList) { + int vectorCnt = vecList.length; + int[] typeIds = new int[vectorCnt]; + long[] vecNativeIds = new long[vectorCnt]; + long rtn = recordReaderNext(parquetReader, typeIds, vecNativeIds); + if (rtn == 0) { + return 0; + } + int nativeGetId = 0; + for (int i = 0; i < vectorCnt; i++) { + switch (DataType.DataTypeId.values()[typeIds[nativeGetId]]) { + case OMNI_BOOLEAN: { + vecList[i] = new BooleanVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_SHORT: { + vecList[i] = new ShortVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_DATE32: { + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_INT: { + vecList[i] = new IntVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_LONG: + case OMNI_DECIMAL64: { + vecList[i] = new LongVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_DOUBLE: { + vecList[i] = new DoubleVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_VARCHAR: { + vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); + break; + } + case OMNI_DECIMAL128: { + vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); + break; + } + default: { + throw new RuntimeException("UnSupport type for ColumnarFileScan:" + + DataType.DataTypeId.values()[typeIds[i]]); + } + } + nativeGetId++; + } + return (int)rtn; + } + + public void close() { + recordReaderClose(parquetReader); + } + + public native long initializeReader(JSONObject job); + + public native long recordReaderNext(long parquetReader, int[] typeId, long[] vecNativeId); + + public native void recordReaderClose(long parquetReader); + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 10cdb08496dbf7a0229b075feb598616a5b86f79..c170b04e4a4b678d962200772cf0c542bed591c4 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -75,6 +75,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader { + + // The capacity of vectorized batch. 
+ private int capacity; + private FilterCompat.Filter filter; + private ParquetMetadata fileFooter; + private boolean[] missingColumns; + private ColumnarBatch columnarBatch; + private MessageType fileSchema; + private MessageType requestedSchema; + private StructType sparkSchema; + private ParquetColumnarBatchJniReader reader; + private org.apache.spark.sql.vectorized.ColumnVector[] wrap; + + // Store the immutable cols, such as partitionCols and missingCols, which are initialized only once. + // wrap will slice vecs from templateWrap when nextBatch() is called. + private org.apache.spark.sql.vectorized.ColumnVector[] templateWrap; + private Vec[] vecs; + private boolean isFilterPredicate = false; + + public OmniParquetColumnarBatchReader(int capacity, ParquetMetadata fileFooter) { + this.capacity = capacity; + this.fileFooter = fileFooter; + } + + public ParquetColumnarBatchJniReader getReader() { + return this.reader; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + // Free vecs from templateWrap. + for (int i = 0; i < templateWrap.length; i++) { + OmniColumnVector vector = (OmniColumnVector) templateWrap[i]; + vector.close(); + } + } + + @Override + public Void getCurrentKey() { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() { + return columnarBatch; + } + + @Override + public boolean nextKeyValue() throws IOException { + return nextBatch(); + } + + @Override + public float getProgress() throws IOException { + return 0; + } + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException, UnsupportedOperationException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + + this.filter = getFilter(configuration); + this.isFilterPredicate = filter instanceof FilterCompat.FilterPredicateCompat ? true : false; + + this.fileSchema = fileFooter.getFileMetaData().getSchema(); + Map fileMetadata = fileFooter.getFileMetaData().getKeyValueMetaData(); + ReadSupport readSupport = getReadSupportInstance(getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + String sparkRequestedSchemaString = configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA()); + this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); + this.reader = new ParquetColumnarBatchJniReader(); + // Push down row group and column indices to the native reader. + List rowgroupIndices = getFilteredBlocks(split.getStart(), split.getEnd()); + List columnIndices = getColumnIndices(requestedSchema.getColumns(), fileSchema.getColumns()); + String ugi = UserGroupInformation.getCurrentUser().toString(); + reader.initializeReaderJava(split.getPath().toString(), capacity, rowgroupIndices, columnIndices, ugi); + // Initialize the missing column flags.
+ initializeInternal(); + } + + private List getFilteredBlocks(long start, long end) throws IOException, InterruptedException { + List res = new ArrayList<>(); + List blocks = fileFooter.getBlocks(); + for (int i = 0; i < blocks.size(); i++) { + BlockMetaData block = blocks.get(i); + long totalSize = 0; + long startIndex = block.getStartingPos(); + for (ColumnChunkMetaData col : block.getColumns()) { + totalSize += col.getTotalSize(); + } + long midPoint = startIndex + totalSize / 2; + if (midPoint >= start && midPoint < end) { + if (isFilterPredicate) { + boolean drop = StatisticsFilter.canDrop(((FilterCompat.FilterPredicateCompat) filter).getFilterPredicate(), + block.getColumns()); + if (!drop) { + res.add(i); + } + } else { + res.add(i); + } + } + } + return res; + } + + private List getColumnIndices(List requestedColumns, List allColumns) { + List res = new ArrayList<>(); + for (int i = 0; i < requestedColumns.size(); i++) { + ColumnDescriptor it = requestedColumns.get(i); + for (int j = 0; j < allColumns.size(); j++) { + if (it.toString().equals(allColumns.get(j).toString())) { + res.add(j); + break; + } + } + } + + if (res.size() != requestedColumns.size()) { + throw new ParquetDecodingException("Parquet mapping column indices error"); + } + return res; + } + + private void initializeInternal() throws IOException, UnsupportedOperationException { + // Check that the requested schema is supported. + missingColumns = new boolean[requestedSchema.getFieldCount()]; + List columns = requestedSchema.getColumns(); + List paths = requestedSchema.getPaths(); + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new UnsupportedOperationException("Complex types not supported."); + } + + String[] colPath = paths.get(i); + if (fileSchema.containsPath(colPath)) { + ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); + if (!fd.equals(columns.get(i))) { + throw new UnsupportedOperationException("Schema evolution not supported."); + } + missingColumns[i] = false; + } else { + if (columns.get(i).getMaxDefinitionLevel() == 0) { + // Column is missing in data but the required data is non-nullable. This file is invalid. + throw new IOException("Required column is missing in data file. Col: " + Arrays.toString(colPath)); + } + missingColumns[i] = true; + } + } + } + + // Creates a columnar batch that includes the schema from the data files and the additional + // partition columns appended to the end of the batch. 
+ // For example, if the data contains two columns, with 2 partition columns: + // Columns 0,1: data columns + // Column 2: partitionValues[0] + // Column 3: partitionValues[1] + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + StructType batchSchema = new StructType(); + for (StructField f: sparkSchema.fields()) { + batchSchema = batchSchema.add(f); + } + if (partitionColumns != null) { + for (StructField f : partitionColumns.fields()) { + batchSchema = batchSchema.add(f); + } + } + wrap = new org.apache.spark.sql.vectorized.ColumnVector[batchSchema.length()]; + columnarBatch = new ColumnarBatch(wrap); + // Initialize the template wrap as well. + templateWrap = new org.apache.spark.sql.vectorized.ColumnVector[batchSchema.length()]; + // Init partition columns + if (partitionColumns != null) { + int partitionIdx = sparkSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + OmniColumnVector partitionCol = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), true); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + // templateWrap always stores the partition column. + templateWrap[i + partitionIdx] = partitionCol; + // wrap also needs a new partition column vector, but its vec is not initialized here. + wrap[i + partitionIdx] = new OmniColumnVector(capacity, partitionColumns.fields()[i].dataType(), false); + } + } + + // Initialize missing columns with nulls. + for (int i = 0; i < missingColumns.length; i++) { + // templateWrap always stores the missing columns; the other columns requested from the native reader are not initialized here. + if (missingColumns[i]) { + OmniColumnVector missingCol = new OmniColumnVector(capacity, sparkSchema.fields()[i].dataType(), true); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + templateWrap[i] = missingCol; + } else { + templateWrap[i] = new OmniColumnVector(capacity, sparkSchema.fields()[i].dataType(), false); + } + + // wrap also needs a new column vector here, but its vec is not initialized. + wrap[i] = new OmniColumnVector(capacity, sparkSchema.fields()[i].dataType(), false); + } + vecs = new Vec[requestedSchema.getFieldCount()]; + } + + /** + * Advance to the next batch of rows. Return false if there are no more. + */ + public boolean nextBatch() throws IOException { + int batchSize = reader.next(vecs); + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + + for (int i = 0; i < requestedSchema.getFieldCount(); i++) { + if (!missingColumns[i]) { + ((OmniColumnVector) wrap[i]).setVec(vecs[i]); + } + } + + // Slice other vecs from templateWrap.
+ for (int i = 0; i < templateWrap.length; i++) { + OmniColumnVector vector = (OmniColumnVector) templateWrap[i]; + if (vector.isConstant()) { + ((OmniColumnVector) wrap[i]).setVec(vector.getVec().slice(0, batchSize)); + } + } + return true; + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + @SuppressWarnings("unchecked") + private Class> getReadSupportClass(Configuration configuration) { + return (Class>) ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance(Class> readSupportClass) { + try { + return readSupportClass.getConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java index 808f96e1fb666def4ff9fc224f01020a81a5baf7..35f46f04ec89f5eaced46c61affde30cfe54d5cc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java @@ -162,24 +162,31 @@ public class OmniColumnVector extends WritableColumnVector { super.close(); if (booleanDataVec != null) { booleanDataVec.close(); + booleanDataVec = null; } if (shortDataVec != null) { shortDataVec.close(); + shortDataVec = null; } if (intDataVec != null) { intDataVec.close(); + intDataVec = null; } if (longDataVec != null) { longDataVec.close(); + longDataVec = null; } if (doubleDataVec != null) { doubleDataVec.close(); + doubleDataVec = null; } if (decimal128DataVec != null) { decimal128DataVec.close(); + decimal128DataVec = null; } if (charsTypeDataVec != null) { charsTypeDataVec.close(); + charsTypeDataVec = null; } if (dictionaryData != null) { dictionaryData.close(); @@ -194,32 +201,32 @@ public class OmniColumnVector extends WritableColumnVector { @Override public boolean hasNull() { if (dictionaryData != null) { - return dictionaryData.hasNullValue(); + return dictionaryData.hasNull(); } if (type instanceof BooleanType) { - return booleanDataVec.hasNullValue(); + return booleanDataVec.hasNull(); } else if (type instanceof ByteType) { - return charsTypeDataVec.hasNullValue(); + return charsTypeDataVec.hasNull(); } else if (type instanceof ShortType) { - return shortDataVec.hasNullValue(); + return shortDataVec.hasNull(); } else if (type instanceof IntegerType) { - return intDataVec.hasNullValue(); + return intDataVec.hasNull(); } else if (type instanceof DecimalType) { if (DecimalType.is64BitDecimalType(type)) { - return longDataVec.hasNullValue(); + return longDataVec.hasNull(); } else { - return decimal128DataVec.hasNullValue(); + return decimal128DataVec.hasNull(); } } else if (type instanceof LongType || 
DecimalType.is64BitDecimalType(type)) { - return longDataVec.hasNullValue(); + return longDataVec.hasNull(); } else if (type instanceof FloatType) { return false; } else if (type instanceof DoubleType) { - return doubleDataVec.hasNullValue(); + return doubleDataVec.hasNull(); } else if (type instanceof StringType) { - return charsTypeDataVec.hasNullValue(); + return charsTypeDataVec.hasNull(); } else if (type instanceof DateType) { - return intDataVec.hasNullValue(); + return intDataVec.hasNull(); } throw new UnsupportedOperationException("hasNull is not supported for type:" + type); } @@ -267,8 +274,8 @@ public class OmniColumnVector extends WritableColumnVector { @Override public void putNulls(int rowId, int count) { - boolean[] nullValue = new boolean[count]; - Arrays.fill(nullValue, true); + byte[] nullValue = new byte[count]; + Arrays.fill(nullValue, (byte) 1); if (dictionaryData != null) { dictionaryData.setNulls(rowId, nullValue, 0, count); return; @@ -749,7 +756,18 @@ public class OmniColumnVector extends WritableColumnVector { @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { - throw new UnsupportedOperationException("putByteArray is not supported"); + if (type instanceof StringType) { + putBytes(rowId, length, value, offset); + return length; + } else if (type instanceof DecimalType && DecimalType.isByteArrayDecimalType(type)) { + byte[] array = new byte[length]; + System.arraycopy(value, offset, array, 0, length); + BigInteger bigInteger = new BigInteger(array); + decimal128DataVec.setBigInteger(rowId, bigInteger); + return length; + } else { + throw new UnsupportedOperationException("putByteArray is not supported for type" + type); + } } /** @@ -806,7 +824,7 @@ public class OmniColumnVector extends WritableColumnVector { if (type instanceof BooleanType) { booleanDataVec = new BooleanVec(newCapacity); } else if (type instanceof ByteType) { - charsTypeDataVec = new VarcharVec(newCapacity * 4, newCapacity); + charsTypeDataVec = new VarcharVec(newCapacity); } else if (type instanceof ShortType) { shortDataVec = new ShortVec(newCapacity); } else if (type instanceof IntegerType) { @@ -825,7 +843,7 @@ public class OmniColumnVector extends WritableColumnVector { doubleDataVec = new DoubleVec(newCapacity); } else if (type instanceof StringType) { // need to set with real column size, suppose char(200) utf8 - charsTypeDataVec = new VarcharVec(newCapacity * 4 * 200, newCapacity); + charsTypeDataVec = new VarcharVec(newCapacity); } else if (type instanceof DateType) { intDataVec = new IntVec(newCapacity); } else { @@ -838,4 +856,8 @@ public class OmniColumnVector extends WritableColumnVector { protected OmniColumnVector reserveNewColumn(int capacity, DataType type) { return new OmniColumnVector(capacity, type, true); } + + public boolean isConstant() { + return isConstant; + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala index a4e4eaa0a877f7ee2e3401ecf4ee98fecfcb7314..e48f958c89fd79370a01a5f2fe2bcae56716510f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala @@ -56,6 +56,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { columnarConf.enableColumnarBroadcastJoin val 
enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin + val enableSortMergeJoinFusion: Boolean = columnarConf.enableSortMergeJoinFusion val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle @@ -104,7 +105,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.child, plan.testSpillFrequency).buildCheck() case plan: BroadcastExchangeExec => if (!enableColumnarBroadcastExchange) return false - new ColumnarBroadcastExchangeExec(plan.mode, plan.child) + new ColumnarBroadcastExchangeExec(plan.mode, plan.child).buildCheck() case plan: TakeOrderedAndProjectExec => if (!enableTakeOrderedAndProject) return false ColumnarTakeOrderedAndProjectExec( @@ -161,14 +162,25 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.right).buildCheck() case plan: SortMergeJoinExec => if (!enableColumnarSortMergeJoin) return false - new ColumnarSortMergeJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.condition, - plan.left, - plan.right, - plan.isSkewJoin).buildCheck() + if (enableSortMergeJoinFusion) { + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + plan.left, + plan.right, + plan.isSkewJoin).buildCheck() + } else { + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + plan.left, + plan.right, + plan.isSkewJoin).buildCheck() + } case plan: WindowExec => if (!enableColumnarWindow) return false ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index fca65b3723f4fbad0e993d268e0fbc76a300bb92..c4b5e87b2b0fdfeba0e840a07636136f6c3daa75 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -17,21 +17,27 @@ package com.huawei.boostkit.spark +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery -import org.apache.spark.sql.catalyst.expressions.aggregate.Partial +import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, SortOrder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial} +import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ColumnarCustomShuffleReaderExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import 
org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation +import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.Aggregate -case class ColumnarPreOverrides() extends Rule[SparkPlan] { +case class ColumnarPreOverrides() extends Rule[SparkPlan] with PredicateHelper{ val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan val enableColumnarProject: Boolean = columnarConf.enableColumnarProject @@ -44,6 +50,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { columnarConf.enableColumnarBroadcastJoin val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange + val enableSortMergeJoinFusion: Boolean = columnarConf.enableSortMergeJoinFusion val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin val enableColumnarSort: Boolean = columnarConf.enableColumnarSort val enableColumnarWindow: Boolean = columnarConf.enableColumnarWindow @@ -52,6 +59,11 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion val enableFusion: Boolean = columnarConf.enableFusion var isSupportAdaptive: Boolean = true + val enableColumnarProjectFusion: Boolean = columnarConf.enableColumnarProjectFusion + val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort + val topNSortThreshold: Int = columnarConf.topNSortThreshold + val enableDedupLeftSemiJoin: Boolean = columnarConf.enableDedupLeftSemiJoin + val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold def apply(plan: SparkPlan): SparkPlan = { replaceWithColumnarPlan(plan) @@ -66,6 +78,19 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } } + def isTopNExpression(e: Expression): Boolean = e match { + case Alias(child, _) => isTopNExpression(child) + case WindowExpression(windowFunction, _) + if windowFunction.isInstanceOf[Rank] => + true + case _ => false + } + + def isStrictTopN(e: Expression): Boolean = e match { + case Alias(child, _) => isStrictTopN(child) + case WindowExpression(windowFunction, _) => windowFunction.isInstanceOf[RowNumber] + } + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { case plan: RowGuard => val actualPlan: SparkPlan = plan.child match { @@ -118,13 +143,124 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { child match { case ColumnarFilterExec(condition, child) => ColumnarConditionProjectExec(plan.projectList, condition, child) + case join : ColumnarBroadcastHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarBroadcastHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isNullAwareAntiJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarShuffledHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + 
ColumnarShuffledHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarSortMergeJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + if(enableSortMergeJoinFusion && join.left.isInstanceOf[SortExec] && join.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(join.left.asInstanceOf[SortExec]) + val right = replaceWithColumnarPlan(join.right.asInstanceOf[SortExec]) + ColumnarSortMergeJoinFusionExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.condition, + left, + right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarSortMergeJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } + } else { + ColumnarProjectExec(plan.projectList, child) + } case _ => ColumnarProjectExec(plan.projectList, child) } case plan: FilterExec if enableColumnarFilter => + if(enableColumnarTopNSort) { + val filterExec = plan.transform { + case f@FilterExec(condition, + w@WindowExec(Seq(windowExpression), _, orderSpec, sort: SortExec)) + if orderSpec.nonEmpty && isTopNExpression(windowExpression) => + var topn = Int.MaxValue + val nonTopNConditions = splitConjunctivePredicates(condition).filter { + case LessThan(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case GreaterThan(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case EqualTo(e: NamedExpression, IntegerLiteral(n)) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case EqualTo(IntegerLiteral(n), e: NamedExpression) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case _ => true + } + + if (topn > 0 && topn <= topNSortThreshold) { + val strictTopN = isStrictTopN(windowExpression) + val topNSortExec = ColumnarTopNSortExec( + topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, replaceWithColumnarPlan(sort.child)) + logInfo(s"Columnar Processing for ${topNSortExec.getClass} is currently supported.") + val newCondition = if (nonTopNConditions.isEmpty) { + Literal.TrueLiteral + } else { + nonTopNConditions.reduce(And) + } + val window = ColumnarWindowExec(w.windowExpression, w.partitionSpec, w.orderSpec, topNSortExec) + return ColumnarFilterExec(newCondition, window) + } else { + logInfo{s"topn: ${topn} is bigger than topNSortThreshold: ${topNSortThreshold}."} + val child = replaceWithColumnarPlan(f.child) + return ColumnarFilterExec(f.condition, child) + } + } + } val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarFilterExec(plan.condition, child) + case plan: ExpandExec if enableColumnarExpand => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -145,7 +281,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { 
join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, filter @ ColumnarFilterExec(_, scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) - ), _, _)), _, _)), _, _)), _, _)) + ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -176,7 +312,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { proj3 @ ColumnarProjectExec(_, join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _)) , _, _)), _, _)) + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -205,7 +341,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { proj3 @ ColumnarProjectExec(_, join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _)), _, _)) + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -295,19 +431,219 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.condition, left, right) + // DeduplicateRightSideOfLeftSemiJoin Rule works only for Spark 3.1. + case plan: SortMergeJoinExec if enableColumnarSortMergeJoin && enableDedupLeftSemiJoin => { + plan.joinType match { + case LeftSemi => { + if (plan.condition.isEmpty && plan.left.isInstanceOf[SortExec] && plan.right.isInstanceOf[SortExec] + && plan.right.asInstanceOf[SortExec].child.isInstanceOf[ShuffleExchangeExec]) { + val nextChild = plan.right.asInstanceOf[SortExec].child.asInstanceOf[ShuffleExchangeExec].child + if (nextChild.output.size >= dedupLeftSemiJoinThreshold) { + nextChild match { + case ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _, _, _)) => { + val left = replaceWithColumnarPlan(plan.left) + val val1 = replaceWithColumnarPlan(nextChild.asInstanceOf[ProjectExec]) + val partialAgg = PhysicalAggregation.unapply(Aggregate(nextChild.output, nextChild.output, + new DummyLogicalPlan)) match { + case Some((groupingExpressions, aggExpressions, resultExpressions, _)) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + ExtendedAggUtils.planPartialAggregateWithoutDistinct( + ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), + aggExpressions.map(_.asInstanceOf[AggregateExpression]), + resultExpressions, + val1) + } + + if (partialAgg.isInstanceOf[HashAggregateExec]) { + val newHashAgg = new ColumnarHashAggregateExec( + partialAgg.asInstanceOf[HashAggregateExec].requiredChildDistributionExpressions, + partialAgg.asInstanceOf[HashAggregateExec].groupingExpressions, + partialAgg.asInstanceOf[HashAggregateExec].aggregateExpressions, + partialAgg.asInstanceOf[HashAggregateExec].aggregateAttributes, + partialAgg.asInstanceOf[HashAggregateExec].initialInputBufferOffset, + partialAgg.asInstanceOf[HashAggregateExec].resultExpressions, + val1) + + val newShuffle = new ColumnarShuffleExchangeExec( + plan.right.asInstanceOf[SortExec].child.asInstanceOf[ShuffleExchangeExec].outputPartitioning, + newHashAgg, + 
plan.right.asInstanceOf[SortExec].child.asInstanceOf[ShuffleExchangeExec].shuffleOrigin + ) + val newSort = new ColumnarSortExec( + plan.right.asInstanceOf[SortExec].sortOrder, + plan.right.asInstanceOf[SortExec].global, + newShuffle, + plan.right.asInstanceOf[SortExec].testSpillFrequency) + ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + newSort, + plan.isSkewJoin) + } else { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] + && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } + case _ => { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] + && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } + } else { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } else { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } + case _ => { + logInfo(s"Columnar Processing for ${plan.getClass} is currently 
supported.") + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } + } case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarSortMergeJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.condition, - left, - right, - plan.isSkewJoin) + if (enableSortMergeJoinFusion && plan.left.isInstanceOf[SortExec] && plan.right.isInstanceOf[SortExec]) { + val left = replaceWithColumnarPlan(plan.left.asInstanceOf[SortExec].child) + val right = replaceWithColumnarPlan(plan.right.asInstanceOf[SortExec].child) + new ColumnarSortMergeJoinFusionExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + } case plan: SortExec if enableColumnarSort => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -315,7 +651,16 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { case plan: WindowExec if enableColumnarWindow => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + child match { + case ColumnarSortExec(sortOrder, _, sortChild, _) => + if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) + } else { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case _ => + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } case plan: UnionExec if enableColumnarUnion => val children = plan.children.map(replaceWithColumnarPlan) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -377,7 +722,23 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { - replaceWithColumnarPlan(plan) + handleColumnarToRowParitalFetch(replaceWithColumnarPlan(plan)) + } + + private def handleColumnarToRowParitalFetch(plan: SparkPlan): SparkPlan = { + // simple check plan tree have OmniColumnarToRow and no LimitExec and TakeOrderedAndProjectExec plan + val noParitalFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { + (!plan.find(node => + node.isInstanceOf[LimitExec] || 
node.isInstanceOf[TakeOrderedAndProjectExec] || + node.isInstanceOf[SortMergeJoinExec]).isDefined) + } else { + false + } + val newPlan = plan.transformUp { + case c: OmniColumnarToRowExec if noParitalFetch => + c.copy(c.child, false) + } + newPlan } def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } @@ -469,5 +830,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { logInfo("Using BoostKit Spark Native Sql Engine Extension to Speed Up Your Queries.") extensions.injectColumnar(session => ColumnarOverrideRules(session)) extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) + extensions.injectOptimizerRule(_ => DelayCartesianProduct) + extensions.injectOptimizerRule(_ => HeuristicJoinReorder) } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index 58eef4125480f04cdb7e80acf247b1525c6e7313..fb45820be6c046f10860defca83055d6d99cd069 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -55,6 +55,13 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val enableColumnarSort: Boolean = conf.getConfString("spark.omni.sql.columnar.sort", "true").toBoolean + // enable or disable topNSort + val enableColumnarTopNSort: Boolean = + conf.getConfString("spark.omni.sql.columnar.topnsort", "true").toBoolean + + val topNSortThreshold: Int = + conf.getConfString("spark.omni.sql.columnar.topnsortthreshold", "100").toInt + val enableColumnarUnion: Boolean = conf.getConfString("spark.omni.sql.columnar.union", "true").toBoolean @@ -75,11 +82,33 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .getConfString("spark.omni.sql.columnar.broadcastJoin", "true") .toBoolean + // enable or disable share columnar BroadcastHashJoin hashtable + val enableShareBroadcastJoinHashTable: Boolean = conf + .getConfString("spark.omni.sql.columnar.broadcastJoin.sharehashtable", "true") + .toBoolean + + // enable or disable heuristic join reorder + val enableHeuristicJoinReorder: Boolean = + conf.getConfString("spark.omni.sql.columnar.heuristicJoinReorder", "true").toBoolean + + // enable or disable delay cartesian product + val enableDelayCartesianProduct: Boolean = + conf.getConfString("spark.omni.sql.columnar.delayCartesianProduct", "true").toBoolean + // enable native table scan val enableColumnarFileScan: Boolean = conf .getConfString("spark.omni.sql.columnar.nativefilescan", "true") .toBoolean + // enable native table scan + val enableOrcNativeFileScan: Boolean = conf + .getConfString("spark.omni.sql.columnar.orcNativefilescan", "true") + .toBoolean + + val enableSortMergeJoinFusion: Boolean = conf + .getConfString("spark.omni.sql.columnar.sortMergeJoin.fusion", "false") + .toBoolean + val enableColumnarSortMergeJoin: Boolean = conf .getConfString("spark.omni.sql.columnar.sortMergeJoin", "true") .toBoolean @@ -131,7 +160,11 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // columnar sort spill threshold val columnarSortSpillRowThreshold: Integer = - conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", "200000").toInt + conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", 
Integer.MAX_VALUE.toString).toInt + + // columnar sort spill threshold - Percentage of memory usage, associate with the "spark.memory.offHeap" together + val columnarSortSpillMemPctThreshold: Integer = + conf.getConfString("spark.omni.sql.columnar.sortSpill.memFraction", "90").toInt // columnar sort spill dir disk reserve Size, default 10GB val columnarSortSpillDirDiskReserveSize:Long = @@ -148,7 +181,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .toBoolean val enableFusion: Boolean = conf - .getConfString("spark.omni.sql.columnar.fusion", "true") + .getConfString("spark.omni.sql.columnar.fusion", "false") .toBoolean // Pick columnar shuffle hash join if one side join count > = 0 to build local hash map, and is @@ -163,9 +196,21 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val maxRowCount = conf.getConfString("spark.sql.columnar.maxRowCount", "20000").toInt + val mergedBatchThreshold = + conf.getConfString("spark.sql.columnar.mergedBatchThreshold", "100").toInt + val enableColumnarUdf: Boolean = conf.getConfString("spark.omni.sql.columnar.udf", "true").toBoolean val enableOmniExpCheck : Boolean = conf.getConfString("spark.omni.sql.omniExp.check", "true").toBoolean + + val enableColumnarProjectFusion : Boolean = conf.getConfString("spark.omni.sql.columnar.projectFusion", "true").toBoolean + + // enable or disable deduplicate the right side of left semi join + val enableDedupLeftSemiJoin: Boolean = + conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoin", "true").toBoolean + + val dedupLeftSemiJoinThreshold: Int = + conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoinThreshold", "3").toInt } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala index e773a780dcfa6e66dc8c96e97e29d80f59703e73..9d7f844bcc19601ac065083b988085c340631ad3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala @@ -24,7 +24,7 @@ import nova.hetu.omniruntime.`type`.DataType.DataTypeId * @since 2022/4/15 */ object Constant { - val DEFAULT_STRING_TYPE_LENGTH = 2000 + val DEFAULT_STRING_TYPE_LENGTH = 50 val OMNI_VARCHAR_TYPE: String = DataTypeId.OMNI_VARCHAR.ordinal().toString val OMNI_SHOR_TYPE: String = DataTypeId.OMNI_SHORT.ordinal().toString val OMNI_INTEGER_TYPE: String = DataTypeId.OMNI_INT.ordinal().toString diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala index 19da63cafad7b3bf2f0e1a863060b2eedae2935f..9a45854da24375e1db310356db943df8f245fa49 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala @@ -63,22 +63,43 @@ object ShuffleJoinStrategy extends Strategy buildRight = true } - getBuildSide( - canBuildShuffledHashJoinLeft(joinType) && buildLeft, - canBuildShuffledHashJoinRight(joinType) && buildRight, - left, - right - ).map { - buildSide => - Seq(joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - nonEquiCond, - planLater(left), - planLater(right))) - 
}.getOrElse(Nil) + // use cbo statistics to take effect if CBO is enable + if (conf.cboEnabled) { + getShuffleHashJoinBuildSide(left, + right, + joinType, + hint, + false, + conf) + .map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + nonEquiCond, + planLater(left), + planLater(right))) + }.getOrElse(Nil) + } else { + getBuildSide( + canBuildShuffledHashJoinLeft(joinType) && buildLeft, + canBuildShuffledHashJoinRight(joinType) && buildRight, + left, + right + ).map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + nonEquiCond, + planLater(left), + planLater(right))) + }.getOrElse(Nil) + } } else { Nil } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 170393144eddc51db16bf41980cd1d7377c0cca9..ad1e0511e87ea12b33e6bc7530a92351938fe36b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -30,9 +30,11 @@ import nova.hetu.omniruntime.operator.OmniExprVerify import com.huawei.boostkit.spark.ColumnarPluginConfig import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString +import org.apache.spark.sql.execution import org.apache.spark.sql.hive.HiveUdfAdaptorUtil import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, ShortType, StringType} @@ -40,7 +42,6 @@ import java.util.Locale import scala.collection.mutable object OmniExpressionAdaptor extends Logging { - def getRealExprId(expr: Expression): ExprId = { expr match { case alias: Alias => getRealExprId(alias.child) @@ -299,12 +300,17 @@ object OmniExpressionAdaptor extends Logging { } private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { - def isDecimalOrStringType(dataType: DataType): Boolean = (dataType.isInstanceOf[DecimalType]) || (dataType.isInstanceOf[StringType]) + def isDecimalOrStringType(dataType: DataType): Boolean = (dataType.isInstanceOf[DecimalType]) || (dataType.isInstanceOf[StringType] || (dataType.isInstanceOf[DateType])) // not support Cast(string as !(decimal/string)) and Cast(!(decimal/string) as string) if ((cast.dataType.isInstanceOf[StringType] && !isDecimalOrStringType(cast.child.dataType)) || (!isDecimalOrStringType(cast.dataType) && cast.child.dataType.isInstanceOf[StringType])) { throw new UnsupportedOperationException(s"Unsupported expression: $expr") } + + // not support Cast(double as decimal) + if (cast.dataType.isInstanceOf[DecimalType] && cast.child.dataType.isInstanceOf[DoubleType]) { + throw new UnsupportedOperationException(s"Unsupported expression: $expr") + } } def toOmniLiteral(literal: Literal): String = { @@ -325,6 +331,20 @@ object OmniExpressionAdaptor 
extends Logging { exprsIndexMap: Map[ExprId, Int], returnDatatype: DataType): String = { expr match { + case subquery: execution.ScalarSubquery => + var result: Any = null + try { + result = subquery.eval(InternalRow.empty) + } catch { + case e: IllegalArgumentException => logDebug(e.getMessage) + } + if (result == null) { + ("{\"exprType\":\"LITERAL\",\"dataType\":%s, \"isNull\":true,\"value\":%s}") + .format(sparkTypeToOmniExpJsonType(subquery.dataType), result) + } else { + val literal = Literal(result, subquery.dataType) + toOmniJsonLiteral(literal) + } case unscaledValue: UnscaledValue => ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + "\"function_name\":\"UnscaledValue\", \"arguments\":[%s]}") @@ -500,6 +520,11 @@ object OmniExpressionAdaptor extends Logging { .format(sparkTypeToOmniExpJsonType(lower.dataType), rewriteToOmniJsonExpressionLiteral(lower.child, exprsIndexMap)) + case upper: Upper => + "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"upper\", \"arguments\":[%s]}" + .format(sparkTypeToOmniExpJsonType(upper.dataType), + rewriteToOmniJsonExpressionLiteral(upper.child, exprsIndexMap)) + case length: Length => "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"length\", \"arguments\":[%s]}" .format(sparkTypeToOmniExpJsonType(length.dataType), @@ -548,6 +573,11 @@ object OmniExpressionAdaptor extends Logging { case concat: Concat => getConcatJsonStr(concat, exprsIndexMap) + case round: Round => + "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"round\", \"arguments\":[%s,%s]}" + .format(sparkTypeToOmniExpJsonType(round.dataType), + rewriteToOmniJsonExpressionLiteral(round.child, exprsIndexMap), + rewriteToOmniJsonExpressionLiteral(round.scale, exprsIndexMap)) case attr: Attribute => toOmniJsonAttribute(attr, exprsIndexMap(attr.exprId)) case _ => if (HiveUdfAdaptorUtil.isHiveUdf(expr) && ColumnarPluginConfig.getSessionConf.enableColumnarUdf) { @@ -651,19 +681,19 @@ object OmniExpressionAdaptor extends Logging { } } - def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean = false): FunctionType = { + def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isMergeCount: Boolean = false): FunctionType = { agg.aggregateFunction match { case Sum(_) => OMNI_AGGREGATION_TYPE_SUM case Max(_) => OMNI_AGGREGATION_TYPE_MAX case Average(_) => OMNI_AGGREGATION_TYPE_AVG case Min(_) => OMNI_AGGREGATION_TYPE_MIN case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => - if (isFinal) { + if (isMergeCount) { OMNI_AGGREGATION_TYPE_COUNT_COLUMN } else { OMNI_AGGREGATION_TYPE_COUNT_ALL } - case Count(_) => OMNI_AGGREGATION_TYPE_COUNT_COLUMN + case Count(_) if agg.aggregateFunction.children.size == 1 => OMNI_AGGREGATION_TYPE_COUNT_COLUMN case First(_, true) => OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL case First(_, false) => OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL case _ => throw new UnsupportedOperationException(s"Unsupported aggregate function: $agg") @@ -954,12 +984,16 @@ object OmniExpressionAdaptor extends Logging { joinType match { case FullOuter => OMNI_JOIN_TYPE_FULL - case _: InnerLike => + case Inner => OMNI_JOIN_TYPE_INNER case LeftOuter => OMNI_JOIN_TYPE_LEFT case RightOuter => OMNI_JOIN_TYPE_RIGHT + case LeftSemi => + OMNI_JOIN_TYPE_LEFT_SEMI + case LeftAnti => + OMNI_JOIN_TYPE_LEFT_ANTI case _ => throw new UnsupportedOperationException(s"Join-type[$joinType] is not supported.") } @@ -983,4 +1017,15 @@ object OmniExpressionAdaptor extends Logging { 
} true } + + def isSimpleProjectForAll(project: NamedExpression): Boolean = { + project match { + case attribute: AttributeReference => + true + case alias: Alias => + alias.child.isInstanceOf[AttributeReference] + case _ => + false + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index abbdcb820a582f277fd79de459aa6aca29524ca2..ed99f6b4311a48492438095a87d450f7d9d89a5a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -17,21 +17,24 @@ package com.huawei.boostkit.spark.util -import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import nova.hetu.omniruntime.constants.FunctionType import nova.hetu.omniruntime.operator.OmniOperator -import nova.hetu.omniruntime.operator.config.OverflowConfig +import nova.hetu.omniruntime.operator.aggregator.{OmniAggregationWithExprOperatorFactory, OmniHashAggregationWithExprOperatorFactory} +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.vector._ - -import org.apache.spark.sql.catalyst.expressions.{Attribute, ExprId, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, NamedExpression, SortOrder} import org.apache.spark.sql.execution.datasources.orc.OrcColumnVector import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.vectorized.{OmniColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import scala.collection.mutable.ListBuffer import java.util object OmniAdaptorUtil { @@ -43,16 +46,14 @@ object OmniAdaptorUtil { val input = new Array[Vec](cb.numCols()) for (i <- 0 until cb.numCols()) { val omniVec: Vec = cb.column(i) match { - case vector: OrcColumnVector => - transColumnVector(vector, cb.numRows()) - case vector: OnHeapColumnVector => - transColumnVector(vector, cb.numRows()) case vector: OmniColumnVector => if (!isSlice) { vector.getVec } else { vector.getVec.slice(0, cb.numRows()) } + case vector: ColumnVector => + transColumnVector(vector, cb.numRows()) case _ => throw new UnsupportedOperationException("unsupport column vector!") } @@ -122,7 +123,7 @@ object OmniAdaptorUtil { } offsets(i + 1) = totalSize } - val vec = new VarcharVec(totalSize, columnSize) + val vec = new VarcharVec(columnSize) val values = new Array[Byte](totalSize) for (i <- 0 until columnSize) { if (null != columnVector.getUTF8String(i)) { @@ -272,4 +273,97 @@ object OmniAdaptorUtil { else OverflowConfig.OverflowConfigId.OVERFLOW_CONFIG_NULL } + + def getAggOperator(groupingExpressions: Seq[NamedExpression], + omniGroupByChanel: Array[String], + omniAggChannels: Array[Array[String]], + omniAggChannelsFilter: Array[String], + omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], + omniAggFunctionTypes: Array[FunctionType], + omniAggOutputTypes: 
Array[Array[nova.hetu.omniruntime.`type`.DataType]], + omniInputRaws: Array[Boolean], + omniOutputPartials: Array[Boolean]): OmniOperator = { + var operator: OmniOperator = null + if (groupingExpressions.nonEmpty) { + operator = new OmniHashAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniAggChannelsFilter, + omniSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniInputRaws, + omniOutputPartials, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + } else { + operator = new OmniAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniAggChannelsFilter, + omniSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniInputRaws, + omniOutputPartials, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + } + operator + } + + def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = { + if (projectList.nonEmpty) { + val projectOutput = ListBuffer[Attribute]() + for (project <- projectList) { + for (col <- output) { + if (col.exprId.equals(getProjectAliasExprId(project))) { + projectOutput += col + } + } + } + projectOutput + } else { + output + } + } + + def getIndexArray(output: Seq[Attribute], projectList: Seq[NamedExpression]): Array[Int] = { + if (projectList.nonEmpty) { + val indexList = ListBuffer[Int]() + for (project <- projectList) { + for (i <- output.indices) { + val col = output(i) + if (col.exprId.equals(getProjectAliasExprId(project))) { + indexList += i + } + } + } + indexList.toArray + } else { + output.indices.toArray + } + } + + def reorderVecs(prunedOutput: Seq[Attribute], projectList: Seq[NamedExpression], resultVecs: Array[nova.hetu.omniruntime.vector.Vec], vecs: Array[OmniColumnVector]) = { + for (index <- projectList.indices) { + val project = projectList(index) + for (i <- prunedOutput.indices) { + val col = prunedOutput(i) + if (col.exprId.equals(getProjectAliasExprId(project))) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + } + } + } + } + + def getProjectAliasExprId(project: NamedExpression): ExprId = { + project match { + case alias: Alias => + // The condition of parameter is restricted. If parameter type is alias, its child type must be attributeReference. + alias.child.asInstanceOf[AttributeReference].exprId + case _ => + project.exprId + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala new file mode 100644 index 0000000000000000000000000000000000000000..d038099a9a332ee5cb2d32f7573bf4e1073c8d0c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala @@ -0,0 +1,358 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec +import scala.collection.mutable + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualNullSafe, EqualTo, Expression, IsNotNull, PredicateHelper} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.util.sideBySide + + + + +/** + * Move all cartesian products to the root of the plan + */ +object DelayCartesianProduct extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Extract cliques from the input plans. + * A cliques is a sub-tree(sub-plan) which doesn't have any join with other sub-plan. + * The input plans are picked from left to right + * , until we can't find join condition in the remaining plans. + * The same logic is applied to the remaining plans, until all plans are picked. + * This function can produce a left-deep tree or a bushy tree. + * + * @param input a list of LogicalPlans to inner join and the type of inner join. + * @param conditions a list of condition for join. + */ + private def extractCliques(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + : Seq[(LogicalPlan, InnerLike)] = { + if (input.size == 1) { + input + } else { + val (leftPlan, leftInnerJoinType) :: linearSeq = input + // discover the initial join that contains at least one join condition + val conditionalOption = linearSeq.find { planJoinPair => + val plan = planJoinPair._1 + val refs = leftPlan.outputSet ++ plan.outputSet + conditions + .filterNot(l => l.references.nonEmpty && canEvaluate(l, leftPlan)) + .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan)) + .exists(_.references.subsetOf(refs)) + } + + if (conditionalOption.isEmpty) { + Seq((leftPlan, leftInnerJoinType)) ++ extractCliques(linearSeq, conditions) + } else { + val (rightPlan, rightInnerJoinType) = conditionalOption.get + + val joinedRefs = leftPlan.outputSet ++ rightPlan.outputSet + val (joinConditions, otherConditions) = conditions.partition( + e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) + val joined = Join(leftPlan, rightPlan, rightInnerJoinType, + joinConditions.reduceLeftOption(And), JoinHint.NONE) + + // must not make reference to the same logical plan + extractCliques(Seq((joined, Inner)) + ++ linearSeq.filterNot(_._1 eq rightPlan), otherConditions) + } + } + } + + /** + * Link cliques by cartesian product + * + * @param input + * @return + */ + private def linkCliques(input: Seq[(LogicalPlan, InnerLike)]) + : LogicalPlan = { + if (input.length == 1) { + input.head._1 + } else if (input.length == 2) { + val ((left, innerJoinType1), (right, innerJoinType2)) = (input(0), input(1)) + val joinType = resetJoinType(innerJoinType1, innerJoinType2) + Join(left, right, joinType, None, JoinHint.NONE) + } else { + val (left, innerJoinType1) :: (right, 
innerJoinType2) :: rest = input + val joinType = resetJoinType(innerJoinType1, innerJoinType2) + linkCliques(Seq((Join(left, right, joinType, None, JoinHint.NONE), joinType)) ++ rest) + } + } + + /** + * This is to reset the join type before reordering. + * + * @param leftJoinType + * @param rightJoinType + * @return + */ + private def resetJoinType(leftJoinType: InnerLike, rightJoinType: InnerLike): InnerLike = { + (leftJoinType, rightJoinType) match { + case (_, Cross) | (Cross, _) => Cross + case _ => Inner + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!ColumnarPluginConfig.getSessionConf.enableDelayCartesianProduct) { + return plan + } + + // Reorder joins only when there are cartesian products. + var existCartesianProduct = false + plan foreach { + case Join(_, _, _: InnerLike, None, _) => existCartesianProduct = true + case _ => + } + + if (existCartesianProduct) { + plan.transform { + case originalPlan@ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + val cliques = extractCliques(input, conditions) + val reorderedPlan = linkCliques(cliques) + + reorderedPlan match { + // Generate a bushy tree after reordering. + case ExtractFiltersAndInnerJoinsForBushy(_, joinConditions) => + val primalConditions = conditions.flatMap(splitConjunctivePredicates) + val reorderedConditions = joinConditions.flatMap(splitConjunctivePredicates).toSet + val missingConditions = primalConditions.filterNot(reorderedConditions.contains) + if (missingConditions.nonEmpty) { + val comparedPlans = + sideBySide(originalPlan.treeString, reorderedPlan.treeString).mkString("\n") + logWarning("There are missing conditions after reordering, falling back to the " + + s"original plan. == Comparing two plans ===\n$comparedPlans") + originalPlan + } else { + reorderedPlan + } + case _ => throw new AnalysisException( + s"There is no join node in the plan, this should not happen: $reorderedPlan") + } + } + } else { + plan + } + } +} + +/** + * Firstly, Heuristic reorder join need to execute small joins with filters + * , which can reduce intermediate results + */ +object HeuristicJoinReorder extends Rule[LogicalPlan] + with PredicateHelper with JoinSelectionHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * The joined plan are picked from left to right, thus the final result is a left-deep tree. + * + * @param input a list of LogicalPlans to inner join and the type of inner join. + * @param conditions a list of condition for join. + */ + @tailrec + final def createReorderJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + : LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) + val ((leftPlan, leftJoinType), (rightPlan, rightJoinType)) = (input(0), input(1)) + val innerJoinType = (leftJoinType, rightJoinType) match { + case (Inner, Inner) => Inner + case (_, _) => Cross + } + // Set the join node ordered so that we don't need to transform them again. + val orderJoin = OrderedJoin(leftPlan, rightPlan, innerJoinType, joinConditions.reduceLeftOption(And)) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), orderJoin) + } else { + orderJoin + } + } else { + val (left, _) :: rest = input.toList + val candidates = rest.filter { planJoinPair => + val plan = planJoinPair._1 + // 1. it has join conditions with the left node + // 2. it has a filter + // 3. 
it can be broadcast + val isEqualJoinCondition = conditions.flatMap { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None + case EqualNullSafe(l, r) if l.references.isEmpty || r.references.isEmpty => None + case e@EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) + case e@EqualTo(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) + case e@EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) + case e@EqualNullSafe(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) + case _ => None + }.nonEmpty + + val hasFilter = plan match { + case f: Filter if hasValuableCondition(f.condition) => true + case Project(_, f: Filter) if hasValuableCondition(f.condition) => true + case _ => false + } + + isEqualJoinCondition && hasFilter + } + val (right, innerJoinType) = if (candidates.nonEmpty) { + candidates.minBy(_._1.stats.sizeInBytes) + } else { + rest.head + } + + val joinedRefs = left.outputSet ++ right.outputSet + val selectedJoinConditions = mutable.HashSet.empty[Expression] + val (joinConditions, others) = conditions.partition { e => + // If there are semantically equal conditions, they should come from two different joins. + // So we should not put them into one join. + if (!selectedJoinConditions.contains(e.canonicalized) && e.references.subsetOf(joinedRefs) + && canEvaluateWithinJoin(e)) { + selectedJoinConditions.add(e.canonicalized) + true + } else { + false + } + } + // Set the join node ordered so that we don't need to transform them again. + val joined = OrderedJoin(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createReorderJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) + } + } + + private def hasValuableCondition(condition: Expression): Boolean = { + val conditions = splitConjunctivePredicates(condition) + !conditions.forall(_.isInstanceOf[IsNotNull]) + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (ColumnarPluginConfig.getSessionConf.enableHeuristicJoinReorder) { + val newPlan = plan.transform { + case p@ExtractFiltersAndInnerJoinsByIgnoreProjects(input, conditions) + if input.size > 2 && conditions.nonEmpty => + val reordered = createReorderJoin(input, conditions) + if (p.sameOutput(reordered)) { + reordered + } else { + // Reordering the joins have changed the order of the columns. + // Inject a projection to make sure we restore to the expected ordering. + Project(p.output, reordered) + } + } + + // After reordering is finished, convert OrderedJoin back to Join + val result = newPlan.transformDown { + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) + } + if (!result.resolved) { + // In some special cases related to subqueries, we find that after reordering, + val comparedPlans = sideBySide(plan.treeString, result.treeString).mkString("\n") + logWarning("The structural integrity of the plan is broken, falling back to the " + + s"original plan. == Comparing two plans ===\n$comparedPlans") + plan + } else { + result + } + } else { + plan + } + } +} + +/** + * This is different from [[ExtractFiltersAndInnerJoins]] in that it can collect filters and + * inner joins by ignoring projects on top of joins, which are produced by column pruning. + */ +private object ExtractFiltersAndInnerJoinsByIgnoreProjects extends PredicateHelper { + + /** + * Flatten all inner joins, which are next to each other. 
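+ * (Descriptive note: Projects that contain only attributes, as produced by column pruning, are looked through; see the Project case below.)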
+ * Return a list of logical plans to be joined with a boolean for each plan indicating if it + * was involved in an explicit cross join. Also returns the entire list of join conditions for + * the left-deep tree. + */ + def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { + case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(left, joinType) + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) + case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + case Project(projectList, child) + if projectList.forall(_.isInstanceOf[Attribute]) => flattenJoin(child) + + case _ => (Seq((plan, parentJoinType)), Seq.empty) + } + + def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] + = plan match { + case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => + Some(flattenJoin(f)) + case j@Join(_, _, _, _, hint) if hint == JoinHint.NONE => + Some(flattenJoin(j)) + case _ => None + } +} + +private object ExtractFiltersAndInnerJoinsForBushy extends PredicateHelper { + + /** + * This function works for both left-deep and bushy trees. + * + * @param plan + * @param parentJoinType + * @return + */ + def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { + case Join(left, right, joinType: InnerLike, cond, _) => + val (lPlans, lConds) = flattenJoin(left, joinType) + val (rPlans, rConds) = flattenJoin(right, joinType) + (lPlans ++ rPlans, lConds ++ rConds ++ cond.toSeq) + + case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, _)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq((plan, parentJoinType)), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] = { + plan match { + case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => + Some(flattenJoin(f)) + case j@Join(_, _, _, _, _) => + Some(flattenJoin(j)) + case _ => None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala index 3769441cf7b315638a12ef3dbc8916bb7b96420c..d137388ab3c41c3ee103ac974cb594990379d394 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala @@ -42,6 +42,7 @@ case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: override def doExecute(): RDD[InternalRow] = { val numOutputRows: SQLMetric = longMetric("numOutputRows") val numOutputBatches: SQLMetric = longMetric("numOutputBatches") + val processTime: SQLMetric = longMetric("processTime") val inputRdd: BroadcastColumnarRDD = BroadcastColumnarRDD( sparkContext, metrics, @@ -49,7 +50,7 @@ case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: child.executeBroadcast(), 
StructType.fromAttributes(child.output)) inputRdd.mapPartitions { batches => - ColumnarBatchToInternalRow.convert(output, batches, numOutputRows, numOutputBatches) + ColumnarBatchToInternalRow.convert(output, batches, numOutputRows, numOutputBatches, processTime) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index 72d1aae05d8f9e4a22a7f0bc17e68aca8b157d74..307c72d968322f80e0c28cab463e5231712113d5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.util.concurrent._ +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.sparkTypeToOmniType import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs import nova.hetu.omniruntime.vector.VecBatch import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory @@ -62,6 +63,11 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) @transient private val timeout: Long = SQLConf.get.broadcastTimeout + def buildCheck(): Unit = { + child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + } + @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index 92c6b6145c0c63c0ebd1821f087ede91f2e3d8e2..7144b565e4d8ab526d26d4cbb36bb92cdee2ce77 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import nova.hetu.omniruntime.vector.Vec +import java.util.concurrent.TimeUnit.NANOSECONDS import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer @@ -34,6 +34,8 @@ import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColum import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch +import nova.hetu.omniruntime.vector.Vec + /** * Holds a user defined rule that can be used to inject columnar implementations of various * operators in the plan. The [[preColumnarTransitions]] [[Rule]] can be used to replace @@ -62,9 +64,9 @@ trait ColumnarToRowTransition extends UnaryExecNode * Provides an optimized set of APIs to append row based data to an array of * [[WritableColumnVector]]. 
*/ -private[execution] class RowToColumnConverter(schema: StructType) extends Serializable { +private[execution] class OmniRowToColumnConverter(schema: StructType) extends Serializable { private val converters = schema.fields.map { - f => RowToColumnConverter.getConverterForType(f.dataType, f.nullable) + f => OmniRowToColumnConverter.getConverterForType(f.dataType, f.nullable) } final def convert(row: InternalRow, vectors: Array[WritableColumnVector]): Unit = { @@ -80,7 +82,7 @@ private[execution] class RowToColumnConverter(schema: StructType) extends Serial * Provides an optimized set of APIs to extract a column from a row and append it to a * [[WritableColumnVector]]. */ -private object RowToColumnConverter { +private object OmniRowToColumnConverter { SparkMemoryUtils.init() private abstract class TypeConverter extends Serializable { @@ -226,13 +228,15 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti override lazy val metrics: Map[String, SQLMetric] = Map( "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), - "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches") + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), + "rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in row to OmniColumnar") ) override def doExecuteColumnar(): RDD[ColumnarBatch] = { val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") + val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime") // Instead of creating a new config we are reusing columnBatchSize. In the future if we do // combine with some of the Arrow conversion tools we will need to unify some of the configs. 
val numRows = conf.columnBatchSize @@ -242,13 +246,14 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti child.execute().mapPartitionsInternal { rowIterator => if (rowIterator.hasNext) { new Iterator[ColumnarBatch] { - private val converters = new RowToColumnConverter(localSchema) + private val converters = new OmniRowToColumnConverter(localSchema) override def hasNext: Boolean = { rowIterator.hasNext } override def next(): ColumnarBatch = { + val startTime = System.nanoTime() val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, localSchema, true) val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) @@ -268,6 +273,7 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti cb.setNumRows(rowCount) numInputRows += rowCount numOutputBatches += 1 + rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) cb } } @@ -279,7 +285,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti } -case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition { +case class OmniColumnarToRowExec(child: SparkPlan, + mayPartialFetch: Boolean = true) extends ColumnarToRowTransition { assert(child.supportsColumnar) override def nodeName: String = "OmniColumnarToRow" @@ -292,17 +299,27 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti override lazy val metrics: Map[String, SQLMetric] = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches") + "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"), + "omniColumnarToRowTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omniColumnar to row") ) + override def verboseStringWithOperatorId(): String = { + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("mayPartialFetch", String.valueOf(mayPartialFetch))} + |""".stripMargin + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") + val omniColumnarToRowTime = longMetric("omniColumnarToRowTime") // This avoids calling `output` in the RDD closure, so that we don't need to include the entire // plan (this) in the closure. 
val localOutput = this.output child.executeColumnar().mapPartitionsInternal { batches => - ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches) + ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches, omniColumnarToRowTime, mayPartialFetch) } } } @@ -310,27 +327,61 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti object ColumnarBatchToInternalRow { def convert(output: Seq[Attribute], batches: Iterator[ColumnarBatch], - numOutputRows: SQLMetric, numInputBatches: SQLMetric ): Iterator[InternalRow] = { + numOutputRows: SQLMetric, numInputBatches: SQLMetric, + rowToOmniColumnarTime: SQLMetric, + mayPartialFetch: Boolean = true): Iterator[InternalRow] = { + val startTime = System.nanoTime() val toUnsafe = UnsafeProjection.create(output, output) - val vecsTmp = new ListBuffer[Vec] val batchIter = batches.flatMap { batch => - // store vec since tablescan reuse batch + + // toClosedVecs closed case: + // 1) all rows of batch fetched and closed + // 2) only fetch parital rows(eg: top-n, limit-n), closed at task CompletionListener callback + val toClosedVecs = new ListBuffer[Vec] for (i <- 0 until batch.numCols()) { batch.column(i) match { case vector: OmniColumnVector => - vecsTmp.append(vector.getVec) + toClosedVecs.append(vector.getVec) case _ => } } + numInputBatches += 1 - numOutputRows += batch.numRows() - batch.rowIterator().asScala.map(toUnsafe) - } + val iter = batch.rowIterator().asScala.map(toUnsafe) + rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) + + new Iterator[InternalRow] { + val numOutputRowsMetric: SQLMetric = numOutputRows + var closed = false + + // only invoke if fetch partial rows of batch + if (mayPartialFetch) { + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + if (!closed) { + toClosedVecs.foreach {vec => + vec.close() + } + } + } + } + + override def hasNext: Boolean = { + val has = iter.hasNext + // fetch all rows and closed + if (!has && !closed) { + toClosedVecs.foreach {vec => + vec.close() + } + closed = true + } + has + } - SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - vecsTmp.foreach {vec => - vec.close() + override def next(): InternalRow = { + numOutputRowsMetric += 1 + iter.next() + } } } batchIter diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index e8c3e833fdaf10b7504229a089f063e555363dfe..0c3187461bcea187b17f1d03d9c871880e4de054 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -21,6 +21,7 @@ import java.util.Optional import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.ColumnarPluginConfig import scala.collection.mutable.HashMap import scala.collection.JavaConverters._ @@ -49,6 +50,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import 
org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat} +import org.apache.spark.sql.execution.datasources.parquet.{OmniParquetFileFormat, ParquetFileFormat} import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.util.SparkMemoryUtils @@ -59,8 +61,6 @@ import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.collection.BitSet - - abstract class BaseColumnarFileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], @@ -73,6 +73,10 @@ abstract class BaseColumnarFileSourceScanExec( disableBucketedScan: Boolean = false) extends DataSourceScanExec { + override val nodeName: String = { + s"OmniScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + } + override lazy val supportsColumnar: Boolean = true override def vectorTypes: Option[Seq[String]] = @@ -285,12 +289,24 @@ abstract class BaseColumnarFileSourceScanExec( |""".stripMargin } + val enableColumnarFileScan: Boolean = ColumnarPluginConfig.getSessionConf.enableColumnarFileScan + val enableOrcNativeFileScan: Boolean = ColumnarPluginConfig.getSessionConf.enableOrcNativeFileScan lazy val inputRDD: RDD[InternalRow] = { - val fileFormat: FileFormat = relation.fileFormat match { - case orcFormat: OrcFileFormat => - new OmniOrcFileFormat() - case _ => - throw new UnsupportedOperationException("Unsupported FileFormat!") + val fileFormat: FileFormat = if (enableColumnarFileScan) { + relation.fileFormat match { + case orcFormat: OrcFileFormat => + if (enableOrcNativeFileScan) { + new OmniOrcFileFormat() + } else { + relation.fileFormat + } + case parquetFormat: ParquetFileFormat => + new OmniParquetFileFormat() + case _ => + throw new UnsupportedOperationException("Unsupported FileFormat!") + } + } else { + relation.fileFormat } val readFile: (PartitionedFile) => Iterator[InternalRow] = fileFormat.buildReaderWithPartitionValues( @@ -382,6 +398,7 @@ abstract class BaseColumnarFileSourceScanExec( val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val localSchema = this.schema inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => new Iterator[ColumnarBatch] { @@ -395,9 +412,16 @@ abstract class BaseColumnarFileSourceScanExec( override def next(): ColumnarBatch = { val batch = batches.next() - numOutputRows += batch.numRows() + val input = transColBatchToOmniVecs(batch) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + batch.numRows, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(input(i)) + } + numOutputRows += batch.numRows numOutputVecBatchs += 1 - batch + new ColumnarBatch(vectors.toArray, batch.numRows) } } } @@ -528,9 +552,14 @@ abstract class BaseColumnarFileSourceScanExec( val omniAggFunctionTypes = new Array[FunctionType](agg.aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](agg.aggregateExpressions.size) val omniAggChannels = new Array[Array[String]](agg.aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](agg.aggregateExpressions.size) var omniAggindex = 0 for (exp <- agg.aggregateExpressions) { + if (exp.filter.isDefined) { + omniAggChannelsFilter(omniAggindex) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, 
attrAggExpsIdMap) + } if (exp.mode == Final) { throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp") } else if (exp.mode == Partial) { @@ -588,8 +617,8 @@ abstract class BaseColumnarFileSourceScanExec( case (attr, i) => omniAggSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) } - (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaws, omniAggOutputPartials, resultIdxToOmniResultIdxMap) + (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaws, omniAggOutputPartials, resultIdxToOmniResultIdxMap) } def genProjectOutput(project: ColumnarProjectExec) = { @@ -818,8 +847,8 @@ case class ColumnarMultipleOperatorExec( val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") - val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, relation1) = genJoinOutput(join1) @@ -838,16 +867,15 @@ case class ColumnarMultipleOperatorExec( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggFactory = new OmniHashAggregationWithExprOperatorFactory( + val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, omniAggInputRaw, - omniAggOutputPartial, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val aggOperator = aggFactory.createOperator + omniAggOutputPartial) omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() @@ -862,7 +890,7 @@ case class ColumnarMultipleOperatorExec( }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, - buildJoinColsExp1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, 1, + buildJoinColsExp1, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp1 = buildOpFactory1.createOperator() @@ -878,6 +906,7 @@ case class ColumnarMultipleOperatorExec( buildOp1.getOutput val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp1 = lookupOpFactory1.createOperator() // close operator @@ -895,7 +924,7 @@ case class ColumnarMultipleOperatorExec( }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, - buildJoinColsExp2, if 
(joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, 1, + buildJoinColsExp2, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp2 = buildOpFactory2.createOperator() @@ -911,6 +940,7 @@ case class ColumnarMultipleOperatorExec( buildOp2.getOutput val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp2 = lookupOpFactory2.createOperator() @@ -929,7 +959,7 @@ case class ColumnarMultipleOperatorExec( }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, - buildJoinColsExp3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, 1, + buildJoinColsExp3, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp3 = buildOpFactory3.createOperator() @@ -945,6 +975,7 @@ case class ColumnarMultipleOperatorExec( buildOp3.getOutput val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp3 = lookupOpFactory3.createOperator() @@ -963,7 +994,7 @@ case class ColumnarMultipleOperatorExec( }) val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(buildTypes4, - buildJoinColsExp4, if (joinFilter4.nonEmpty) {Optional.of(joinFilter4.get)} else {Optional.empty()}, 1, + buildJoinColsExp4, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp4 = buildOpFactory4.createOperator() @@ -979,6 +1010,7 @@ case class ColumnarMultipleOperatorExec( buildOp4.getOutput val lookupOpFactory4 = new OmniLookupJoinWithExprOperatorFactory(probeTypes4, probeOutputCols4, probeHashColsExp4, buildOutputCols4, buildOutputTypes4, OMNI_JOIN_TYPE_INNER, buildOpFactory4, + if (joinFilter4.nonEmpty) {Optional.of(joinFilter4.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp4 = lookupOpFactory4.createOperator() @@ -1167,8 +1199,8 @@ case class ColumnarMultipleOperatorExec1( val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") - val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, - omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + val (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, + omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) val (proj1OmniExpressions, proj1OmniInputTypes) = genProjectOutput(proj1) val (buildTypes1, buildJoinColsExp1, joinFilter1, probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, relation1, reserved1) = genJoinOutputWithReverse(join1) @@ -1200,16 +1232,15 @@ case class 
ColumnarMultipleOperatorExec1( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggFactory = new OmniHashAggregationWithExprOperatorFactory( + val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, omniAggInputRaw, - omniAggOutputPartial, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val aggOperator = aggFactory.createOperator + omniAggOutputPartial) omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() @@ -1224,7 +1255,7 @@ case class ColumnarMultipleOperatorExec1( }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, - buildJoinColsExp1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, 1, + buildJoinColsExp1, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp1 = buildOpFactory1.createOperator() @@ -1240,6 +1271,7 @@ case class ColumnarMultipleOperatorExec1( buildOp1.getOutput val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp1 = lookupOpFactory1.createOperator() @@ -1258,7 +1290,7 @@ case class ColumnarMultipleOperatorExec1( }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, - buildJoinColsExp2, if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, 1, + buildJoinColsExp2, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp2 = buildOpFactory2.createOperator() @@ -1274,6 +1306,7 @@ case class ColumnarMultipleOperatorExec1( buildOp2.getOutput val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp2 = lookupOpFactory2.createOperator() @@ -1292,7 +1325,7 @@ case class ColumnarMultipleOperatorExec1( }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, - buildJoinColsExp3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, 1, + buildJoinColsExp3, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp3 = buildOpFactory3.createOperator() @@ -1308,6 +1341,7 @@ case class ColumnarMultipleOperatorExec1( buildOp3.getOutput val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new 
OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp3 = lookupOpFactory3.createOperator() diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 4414c3756d58ea9afd6e5f25e6c57233a7fb37b7..bd9cceaf573158a68add424780c6922d35ee1e72 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -54,6 +54,10 @@ case class ColumnarHashAggregateExec( extends BaseAggregateExec with AliasAwareOutputPartitioning { + override lazy val allAttributes: AttributeSeq = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + override def verboseStringWithOperatorId(): String = { s""" |$formattedNodeName @@ -92,10 +96,12 @@ case class ColumnarHashAggregateExec( val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](aggregateExpressions.size) var omniAggChannels = new Array[Array[String]](aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](aggregateExpressions.size) var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { - throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + omniAggChannelsFilter(index) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrExpsIdMap) } if (exp.mode == Final) { exp.aggregateFunction match { @@ -111,16 +117,13 @@ case class ColumnarHashAggregateExec( } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) omniAggChannels(index) = toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrExpsIdMap) omniInputRaws(index) = false omniOutputPartials(index) = true - if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { - omniAggChannels(index) = null - } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == Partial) { @@ -160,6 +163,12 @@ case class ColumnarHashAggregateExec( checkOmniJsonWhiteList("", omniGroupByChanel) } + for (filter <- omniAggChannelsFilter) { + if (filter != null && !isSimpleColumn(filter)) { + checkOmniJsonWhiteList(filter, new Array[AnyRef](0)) + } + } + // final steps contain all Final mode aggregate if (aggregateExpressions.filter(_.mode == Final).size == aggregateExpressions.size) { val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes @@ -191,6 +200,7 @@ case class ColumnarHashAggregateExec( val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](aggregateExpressions.size) var omniAggChannels = new Array[Array[String]](aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](aggregateExpressions.size) val finalStep = (aggregateExpressions.filter (_.mode == Final).size == 
aggregateExpressions.size) @@ -198,7 +208,8 @@ case class ColumnarHashAggregateExec( var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { - throw new UnsupportedOperationException("Unsupported filter in AggregateExpression") + omniAggChannelsFilter(index) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrExpsIdMap) } if (exp.mode == Final) { exp.aggregateFunction match { @@ -215,16 +226,13 @@ case class ColumnarHashAggregateExec( } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) omniAggChannels(index) = toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrExpsIdMap) omniInputRaws(index) = false omniOutputPartials(index) = true - if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { - omniAggChannels(index) = null - } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == Partial) { @@ -257,16 +265,15 @@ case class ColumnarHashAggregateExec( child.executeColumnar().mapPartitionsWithIndex { (index, iter) => val startCodegen = System.nanoTime() - val factory = new OmniHashAggregationWithExprOperatorFactory( + val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, omniGroupByChanel, omniAggChannels, + omniAggChannelsFilter, omniSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, omniInputRaws, - omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val operator = factory.createOperator + omniOutputPartials) omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index b13f5aca1c20db245c479043c50e2b1c7951d609..3638f865f1b8e7450dfc1cfea07a75bd8f5fb4ae 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffle import org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -73,6 +74,7 @@ class ColumnarShuffleExchangeExec( .createAverageMetric(sparkContext, "avg read batch num rows"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), "numOutputRows" -> SQLMetrics 
.createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics @@ -148,7 +150,8 @@ class ColumnarShuffleExchangeExec( cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => new MergeIterator(iter, StructType.fromAttributes(child.output), - longMetric("numMergedVecBatchs")) + longMetric("numMergedVecBatchs"), + longMetric("bypassVecBatchs")) } } else { cachedShuffleRDD @@ -158,6 +161,7 @@ object ColumnarShuffleExchangeExec extends Logging { val defaultMm3HashSeed: Int = 42; + val rollupConst : String = "spark_grouping_id" def prepareShuffleDependency( rdd: RDD[ColumnarBatch], @@ -173,7 +177,7 @@ object ColumnarShuffleExchangeExec extends Logging { ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { - val rangePartitioner: Option[Partitioner] = newPartitioning match { + val part: Option[Partitioner] = newPartitioning match { case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -213,6 +217,8 @@ object ColumnarShuffleExchangeExec extends Logging { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) Some(part) + case HashPartitioning(_, n) => + Some(new PartitionIdPassthrough(n)) case _ => None } @@ -244,8 +250,7 @@ object ColumnarShuffleExchangeExec extends Logging { (0, new ColumnarBatch(newColumns, cb.numRows)) } - // only used for fallback range partitioning - def computeAndAddRangePartitionId( + def computePartitionId( cbIter: Iterator[ColumnarBatch], partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = { val addPid2ColumnBatch = addPidToColumnBatch() @@ -254,7 +259,7 @@ object ColumnarShuffleExchangeExec extends Logging { val pidArr = new Array[Int](cb.numRows) (0 until cb.numRows).foreach { i => val row = cb.getRow(i) - val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row)) + val pid = part.get.getPartition(partitionKeyExtractor(row)) pidArr(i) = pid } val pidVec = new IntVec(cb.numRows) @@ -268,6 +273,13 @@ object ColumnarShuffleExchangeExec extends Logging { newPartitioning.numPartitions > 1 val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition + def containsRollUp(expressions: Seq[Expression]) : Boolean = { + expressions.exists{ + case attr: AttributeReference if rollupConst.equals(attr.name) => true + case _ => false + } + } + val rddWithPartitionId: RDD[Product2[Int, ColumnarBatch]] = newPartitioning match { case RoundRobinPartitioning(numPartitions) => // partition by random number @@ -287,34 +299,50 @@ object ColumnarShuffleExchangeExec extends Logging { UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) } - val newIter = computeAndAddRangePartitionId(cbIter, partitionKeyExtractor) + val newIter = computePartitionId(cbIter, partitionKeyExtractor) newIter }, isOrderSensitive = isOrderSensitive) - case HashPartitioning(expressions, numPartitions) => - rdd.mapPartitionsWithIndexInternal((_, cbIter) => { - val addPid2ColumnBatch = addPidToColumnBatch() - // omni project - val genHashExpression = genHashExpr() - val omniExpr: String = genHashExpression(expressions, numPartitions, defaultMm3HashSeed, outputAttributes) - val factory = new OmniProjectOperatorFactory(Array(omniExpr), inputTypes, 1, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val 
op = factory.createOperator() - - cbIter.map { cb => - val vecs = transColBatchToOmniVecs(cb, true) - op.addInput(new VecBatch(vecs, cb.numRows())) - val res = op.getOutput - if (res.hasNext) { - val retBatch = res.next() - val pidVec = retBatch.getVectors()(0) - // close return VecBatch - retBatch.close() - addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb) - } else { - throw new Exception("Empty Project Operator Result...") + case h@HashPartitioning(expressions, numPartitions) => + if (containsRollUp(expressions) || expressions.length > 6) { + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val partitionKeyExtractor: InternalRow => Any = { + val projection = + UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) } - } - }, isOrderSensitive = isOrderSensitive) + val newIter = computePartitionId(cbIter, partitionKeyExtractor) + newIter + }, isOrderSensitive = isOrderSensitive) + } else { + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val addPid2ColumnBatch = addPidToColumnBatch() + // omni project + val genHashExpression = genHashExpr() + val omniExpr: String = genHashExpression(expressions, numPartitions, defaultMm3HashSeed, outputAttributes) + val factory = new OmniProjectOperatorFactory(Array(omniExpr), inputTypes, 1, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val op = factory.createOperator() + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + op.close() + }) + + cbIter.map { cb => + val vecs = transColBatchToOmniVecs(cb, true) + op.addInput(new VecBatch(vecs, cb.numRows())) + val res = op.getOutput + if (res.hasNext) { + val retBatch = res.next() + val pidVec = retBatch.getVectors()(0) + // close return VecBatch + retBatch.close() + addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb) + } else { + throw new Exception("Empty Project Operator Result...") + } + } + }, isOrderSensitive = isOrderSensitive) + } case SinglePartition => rdd.mapPartitionsWithIndexInternal((_, cbIter) => { cbIter.map { cb => (0, cb) } @@ -373,4 +401,4 @@ object ColumnarShuffleExchangeExec extends Logging { } } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 7c7001dbc1c468465a0115946aeff9849d51a3df..24ccbccaf61bca83c49f266b746b407bd401441c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -110,6 +110,7 @@ case class ColumnarSortExec( child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => val columnarConf = ColumnarPluginConfig.getSessionConf val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold + val sortSpillMemPctThreshold = columnarConf.columnarSortSpillMemPctThreshold val sortSpillDirDiskReserveSize = columnarConf.columnarSortSpillDirDiskReserveSize val sortSpillEnable = columnarConf.enableSortSpill val sortlocalDirs: Array[File] = generateLocalDirs(sparkConfTmp) @@ -117,7 +118,7 @@ case class ColumnarSortExec( val dirId = hash % sortlocalDirs.length val spillPathDir = sortlocalDirs(dirId).getCanonicalPath val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillPathDir, - sortSpillDirDiskReserveSize, 
sortSpillRowThreshold) + sortSpillDirDiskReserveSize, sortSpillRowThreshold, sortSpillMemPctThreshold) val startCodegen = System.nanoTime() val sortOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, outputCols, sortColsExp, ascendings, nullFirsts, new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..0b33aa0c4a47205a6ef82a6ca7f4ca9743a7e0a2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam} +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.topnsort.OmniTopNSortWithExprOperatorFactory +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarTopNSortExec( + n: Int, + strictTopN: Boolean, + partitionSpec: Seq[Expression], + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryExecNode { + + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarTopNSort" + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputPartitioning: Partitioning = child.outputPartitioning + + protected def withNewChildInternal(newChild: SparkPlan): ColumnarTopNSortExec = + copy(child = newChild) + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + override lazy val metrics = Map( + + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatchs" -> 
SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "outputDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "output data size"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + + def buildCheck(): Unit = { + // currently only supports the rank function of window + // strictTopN: true for row_number, false for rank + if (strictTopN) { + throw new UnsupportedOperationException(s"Unsupported strictTopN is true") + } + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniPartitionChanels: Array[AnyRef] = partitionSpec.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + checkOmniJsonWhiteList("", omniPartitionChanels) + genSortParam(child.output, sortOrder) + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val omniCodegenTime = longMetric("omniCodegenTime") + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniPartitionChanels = partitionSpec.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + + child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => + val startCodegen = System.nanoTime() + val topNSortOperatorFactory = new OmniTopNSortWithExprOperatorFactory(sourceTypes, n, + strictTopN, omniPartitionChanels, sortColsExp, ascendings, nullFirsts, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val topNSortOperator = topNSortOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + topNSortOperator.close() + }) + addAllAndGetIterator(topNSortOperator, iter, this.schema, + longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), + longMetric("outputDataSize")) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index e5534d3c67680a747f1b4d92fcb2c377c81577c9..b44c78803258ca39565ea8a5e5e1f523e478987b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -217,7 +217,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val winExpressions: Seq[Expression] = windowFrameExpressionFactoryPairs.flatMap(_._1) val windowFunType = new Array[FunctionType](winExpressions.size) val omminPartitionChannels = new Array[Int](partitionSpec.size) - val preGroupedChannels = new Array[Int](winExpressions.size) + val 
preGroupedChannels = new Array[Int](0) var windowArgKeys = new Array[String](winExpressions.size) val windowFunRetType = new Array[DataType](winExpressions.size) val omniAttrExpsIdMap = getExprIdMap(child.output) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala index d34b93e5b0da5b61ac35c0824acbf817f1a5e938..a055572ceb083d7912e323c7ce13ec3a0db895f2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.adaptive +import com.huawei.boostkit.spark.ColumnarPluginConfig import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} @@ -24,6 +25,8 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import scala.collection.mutable.ArrayBuffer @@ -151,9 +154,12 @@ case class ColumnarCustomShuffleReaderExec( SQLMetrics.postDriverMetricsUpdatedByValue(sparkContext, executionId, driverAccumUpdates.toSeq) } - @transient override lazy val metrics: Map[String, SQLMetric] = { + override lazy val metrics: Map[String, SQLMetric] = { if (shuffleStage.isDefined) { - Map("numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { + Map( + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { if (isLocalReader) { // We split the mapper partition evenly when creating local shuffle reader, so no // data size info is available. 
@@ -195,7 +201,18 @@ case class ColumnarCustomShuffleReaderExec( throw new IllegalStateException("operating on canonicalized plan") } } - cachedShuffleRDD + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs"), + longMetric("bypassVecBatchs")) + } + } else { + cachedShuffleRDD + } } override protected def doExecute(): RDD[InternalRow] = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..c8ec22e0bcca70622020f232ec421e53b24d8a1a --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala @@ -0,0 +1,99 @@ +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial} +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.execution.SparkPlan + +object ExtendedAggUtils { + def normalizeGroupingExpressions(groupingExpressions: Seq[NamedExpression]) = { + groupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } + } + + def planPartialAggregateWithoutDistinct( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): SparkPlan = { + val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) + createAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = groupingExpressions.map(_.toAttribute), + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, + child = child) + } + + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + isStreaming: Boolean = false, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val useHash = HashAggregateExec.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (useHash) { + HashAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + val objectHashEnabled 
= child.sqlContext.conf.useObjectHashAggregation + val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions) + + if (objectHashEnabled && useObjectHash) { + ObjectHashAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + } + + private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = { + exprs.map { ae => + if (ae.filter.isDefined) { + ae.mode match { + case Partial | Complete => ae + case _ => ae.copy(filter = None) + } + } else { + ae + } + } + } +} + +case class DummyLogicalPlan() extends LeafNode { + override def output: Seq[Attribute] = Nil + + override def computeStats(): Statistics = throw new UnsupportedOperationException +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala new file mode 100644 index 0000000000000000000000000000000000000000..c9a0dcbbfb3f7ff4dda3d60e3b91dfbabcb05c5b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} +import org.apache.spark.sql.types.StructType + +/** + * Prune the partitions of file source based table using partition filters. 
Currently, this rule + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] + * with [[FileScan]]. + * + * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding + * statistics will be updated. And the partition filters will be kept in the filters of returned + * logical plan. + * + * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to + * its underlying [[FileScan]]. And the partition filters will be removed in the filters of + * returned logical plan. + */ +private[sql] object PruneFileSourcePartitions + extends Rule[LogicalPlan] with PredicateHelper { + + private def getPartitionKeyFiltersAndDataFilters( + sparkSession: SparkSession, + relation: LeafNode, + partitionSchema: StructType, + filters: Seq[Expression], + output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) + val partitionColumns = + relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet) + ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + + (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) + } + + private def rebuildPhysicalOperation( + projects: Seq[NamedExpression], + filters: Seq[Expression], + relation: LeafNode): Project = { + val withFilter = if (filters.nonEmpty) { + val filterExpression = filters.reduceLeft(And) + Filter(filterExpression, relation) + } else { + relation + } + Project(projects, withFilter) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case op @ PhysicalOperation(projects, filters, + logicalRelation @ + LogicalRelation(fsRelation @ + HadoopFsRelation( + catalogFileIndex: CatalogFileIndex, + partitionSchema, + _, + _, + _, + _), + _, + _, + _)) + if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => + val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( + fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + logicalRelation.output) + // Fix spark issue SPARK-34119(row 104-113) + if (partitionKeyFilters.nonEmpty) { + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFsRelation = + fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) + // Change table stats based on the sizeInBytes of pruned files + val filteredStats = + FilterEstimation(Filter(partitionKeyFilters.reduce(And), logicalRelation)).estimate + val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) => + (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType)) + }) + val withStats = logicalRelation.catalogTable.map(_.copy( + stats = Some(CatalogStatistics( + sizeInBytes = BigInt(prunedFileIndex.sizeInBytes), + rowCount = filteredStats.flatMap(_.rowCount), + colStats = colStats.getOrElse(Map.empty))))) + val prunedLogicalRelation = logicalRelation.copy( + relation = prunedFsRelation, catalogTable = withStats) + // Keep partition-pruning predicates so that they are visible in physical planning + rebuildPhysicalOperation(projects, filters, prunedLogicalRelation) + } else { + op + } + + case op @ PhysicalOperation(projects, filters, + 
v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) + if filters.nonEmpty && scan.readDataSchema.nonEmpty => + val (partitionKeyFilters, dataFilters) = + getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation, + scan.readPartitionSchema, filters, output) + // The dataFilters are pushed down only once + if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) { + val prunedV2Relation = + v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters)) + // The pushed down partition filters don't need to be reevaluated. + val afterScanFilters = + ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) + rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) + } else { + op + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b64fe9c7eaf322a7219170e89358515c0b9f8346..3392caa54f0cff52820b98196f9cbd0235151ef3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -44,7 +44,9 @@ object OrcUtils extends Logging { "NONE" -> "", "SNAPPY" -> ".snappy", "ZLIB" -> ".zlib", - "LZO" -> ".lzo") + "LZO" -> ".lzo", + "ZSTD" -> ".zstd", + "ZSTD_JNI" -> ".zstd_jni") def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala new file mode 100644 index 0000000000000000000000000000000000000000..ff5af85d0943c7356987e2796bd0f0056fea3145 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2021-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop._ +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +import java.net.URI + +class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logging with Serializable { + + override def shortName(): String = "parquet-native" + + override def toString: String = "PARQUET-NATIVE" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[OmniParquetFileFormat] + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException() + } + + override def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + ParquetUtils.inferSchema(sparkSession, parameters, files) + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + // Prepare hadoopConf + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set( + ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sparkSession.sessionState.conf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val sqlConf = sparkSession.sessionState.conf + + val capacity = sqlConf.parquetVectorizedReaderBatchSize + + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val 
isCaseSensitive = sqlConf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + assert(file.partitionValues.numFields == partitionSchema.size) + + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + + val sharedConf = broadcastedHadoopConf.value.value + + val fileFooter = ParquetFileReader.readFooter(sharedConf, filePath, NO_FILTER) + + val footerFileMetaData = fileFooter.getFileMetaData + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(_)) + .reduceOption(FilterApi.and) + } else { + None + } + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = + new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } + + val batchReader = new OmniParquetColumnarBatchReader(capacity, fileFooter) + + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + SparkMemoryUtils.init() + + batchReader.initialize(split, hadoopAttemptContext) + logDebug(s"Appending $partitionSchema ${file.partitionValues}") + batchReader.initBatch(partitionSchema, file.partitionValues) + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] + } + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/BaseColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/BaseColumnarSortMergeJoinExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..7fc0e8a4adf6ae663263d7046d6e3444546d3a41 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/BaseColumnarSortMergeJoinExec.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} +import nova.hetu.omniruntime.`type`.DataType +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Performs a sort merge join of two child relations. + */ +abstract class BaseColumnarSortMergeJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isSkewJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends ShuffledJoin with CodegenSupport { + + override def supportsColumnar: Boolean = true + + override def supportCodegen: Boolean = false + + override def nodeName: String = { + if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin" + } + + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + + override def requiredChildDistribution: Seq[Distribution] = { + if (isSkewJoin) { + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + } else { + super.requiredChildDistribution + } + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + case _: InnerLike => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + val sameOrderExpressions = ExpressionSet(lKey.sameOrderExpressions ++ rKey.children) + SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq) + } + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) + case FullOuter => Nil + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + val requiredOrdering = requiredOrders(keys) + if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + val sameOrderExpressionSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionSet.toSeq) + } + } else { + requiredOrdering + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + keys.map(SortOrder(_, Ascending)) + } + + override def output : Seq[Attribute] = { + if (projectList.nonEmpty) { + projectList.map(_.toAttribute) + } else { + super[ShuffledJoin].output + } + } + + override def needCopyResult: Boolean = true + + val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 + val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 + val SCAN_FINISH = 4 + + val RES_INIT 
= 0 + val SMJ_FETCH_JOIN_DATA = 5 + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "streamedAddInputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni streamed addInput"), + "streamedCodegenTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni streamed codegen"), + "bufferedAddInputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni buffered addInput"), + "bufferedCodegenTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni buffered codegen"), + "getOutputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni buffered getOutput"), + "numOutputVecBatchs" -> + SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numStreamVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of streamed vecBatchs"), + "numBufferVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatchs") + ) + + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("Stream input", left.output ++ left.output.map(_.dataType))} + |${ExplainUtils.generateFieldString("Buffer input", right.output ++ right.output.map(_.dataType))} + |${ExplainUtils.generateFieldString("Left keys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("Right keys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("Join condition", joinCondStr)} + |${ExplainUtils.generateFieldString("Project List", projectList ++ projectList.map(_.dataType))} + |${ExplainUtils.generateFieldString("Output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute.") + } + + protected override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException(s"This operator doesn't support doProduce.") + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + def buildCheck(): Unit = { + joinType match { + case Inner | LeftOuter | FullOuter | LeftSemi | LeftAnti => + // SMJ join support Inner | LeftOuter | FullOuter | LeftSemi | LeftAnti + case _ => + throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + + s"in ${this.nodeName}") + } + + val streamedTypes = new Array[DataType](left.output.size) + left.output.zipWithIndex.foreach { case (attr, i) => + streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val streamedKeyColsExp: Array[AnyRef] = leftKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + }.toArray + + val bufferedTypes = new Array[DataType](right.output.size) + right.output.zipWithIndex.foreach { case (attr, i) => + bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val bufferedKeyColsExp: Array[AnyRef] = rightKeys.map { x => + 
OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + }.toArray + + if (!isSimpleColumnForAll(streamedKeyColsExp.map(expr => expr.toString))) { + checkOmniJsonWhiteList("", streamedKeyColsExp) + } + + if (!isSimpleColumnForAll(bufferedKeyColsExp.map(expr => expr.toString))) { + checkOmniJsonWhiteList("", bufferedKeyColsExp) + } + + condition match { + case Some(expr) => + val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) + if (!isSimpleColumn(filterExpr)) { + checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0)) + } + case _ => null + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index 48d0419c4188464399eb83e0602abd212ad11daf..5084614947dfe40d3321ca0e1db341407d1091f6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -19,21 +19,20 @@ package org.apache.spark.sql.execution.joins import java.util.Optional import java.util.concurrent.TimeUnit.NANOSECONDS - import scala.collection.mutable - import com.huawei.boostkit.spark.ColumnarPluginConfig import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} +import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.OmniOperator import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} import nova.hetu.omniruntime.vector.VecBatch import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory - import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, SparkPlan} +import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -62,9 +61,28 @@ case class ColumnarBroadcastHashJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan, - 
isNullAwareAntiJoin: Boolean = false) + isNullAwareAntiJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) extends HashJoin { + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("condition", joinCondStr)} + |${ExplainUtils.generateFieldString("projectList", projectList.map(_.toAttribute) ++ projectList.map(_.toAttribute).map(_.dataType))} + |${ExplainUtils.generateFieldString("output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + if (isNullAwareAntiJoin) { require(leftKeys.length == 1, "leftKeys length should be 1") require(rightKeys.length == 1, "rightKeys length should be 1") @@ -88,7 +106,8 @@ case class ColumnarBroadcastHashJoinExec( "buildCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build codegen"), "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), - "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs") + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs") ) override def supportsColumnar: Boolean = true @@ -109,7 +128,7 @@ case class ColumnarBroadcastHashJoinExec( override lazy val outputPartitioning: Partitioning = { joinType match { - case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + case Inner if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) @@ -201,7 +220,7 @@ case class ColumnarBroadcastHashJoinExec( def buildCheck(): Unit = { joinType match { - case LeftOuter | Inner => + case LeftOuter | Inner | LeftSemi => case _ => throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -258,6 +277,7 @@ case class ColumnarBroadcastHashJoinExec( val numOutputRows = longMetric("numOutputRows") val numOutputVecBatchs = longMetric("numOutputVecBatchs") val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val bypassVecBatchs = longMetric("bypassVecBatchs") val buildAddInputTime = longMetric("buildAddInputTime") val buildCodegenTime = longMetric("buildCodegenTime") val buildGetOutputTime = longMetric("buildGetOutputTime") @@ -270,60 +290,102 @@ case class ColumnarBroadcastHashJoinExec( buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShareBuildOp: Boolean = columnarConf.enableShareBroadcastJoinHashTable + val enableJoinBatchMerge: Boolean = columnarConf.enableJoinBatchMerge + // {0}, buildKeys: col1#12 - val buildOutputCols = buildOutput.indices.toArray // {0,1} + val buildOutputCols: Array[Int] = joinType 
match { + case Inner | LeftOuter => + getIndexArray(buildOutput, projectList) + case LeftExistence(_) => + Array[Int]() + case x => + throw new UnsupportedOperationException(s"ColumnBroadcastHashJoin Join-type[$x] is not supported!") + } val buildJoinColsExp = buildKeys.map { x => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) }.toArray val relation = buildPlan.executeBroadcast[ColumnarHashedRelation]() - val buildOutputTypes = buildTypes // {1,1} + val prunedBuildOutput = pruneOutput(buildOutput, projectList) + val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + prunedBuildOutput.zipWithIndex.foreach { case (att, i) => + buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } val probeTypes = new Array[DataType](streamedOutput.size) // {2,2}, streamedOutput:col1#10,col2#11 streamedOutput.zipWithIndex.foreach { case (attr, i) => probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } - val probeOutputCols = streamedOutput.indices.toArray // {0,1} + val probeOutputCols = getIndexArray(streamedOutput, projectList) // {0,1} val probeHashColsExp = streamedKeys.map { x => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) }.toArray + + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val canShareBuildOp = (lookupJoinType != OMNI_JOIN_TYPE_RIGHT && lookupJoinType != OMNI_JOIN_TYPE_FULL) streamedPlan.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => val filter: Optional[String] = condition match { case Some(expr) => Optional.of(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput).map(_.toAttribute)))) - case _ => Optional.empty() + case _ => + Optional.empty() } - val startBuildCodegen = System.nanoTime() - val buildOpFactory = - new OmniHashBuilderWithExprOperatorFactory(buildTypes, buildJoinColsExp, filter, 1, - new OperatorConfig(SpillConfig.NONE, - new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val buildOp = buildOpFactory.createOperator() - buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) - // close operator - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { - buildOp.close() - buildOpFactory.close() - }) + def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { + val startBuildCodegen = System.nanoTime() + val opFactory = + new OmniHashBuilderWithExprOperatorFactory(buildTypes, buildJoinColsExp, 1, + new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val op = opFactory.createOperator() + buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + + val deserializer = VecBatchSerializerFactory.create() + relation.value.buildData.foreach { input => + val startBuildInput = System.nanoTime() + op.addInput(deserializer.deserialize(input)) + buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) + } + val startBuildGetOp = System.nanoTime() + op.getOutput + buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) + (opFactory, op) + } - val deserializer = VecBatchSerializerFactory.create() - relation.value.buildData.foreach { input => - val startBuildInput = 
System.nanoTime() - buildOp.addInput(deserializer.deserialize(input)) - buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) + var buildOp: OmniOperator = null + var buildOpFactory: OmniHashBuilderWithExprOperatorFactory = null + if (enableShareBuildOp && canShareBuildOp) { + OmniHashBuilderWithExprOperatorFactory.gLock.lock() + try { + buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id) + if (buildOpFactory == null) { + val (opFactory, op) = createBuildOpFactoryAndOp() + buildOpFactory = opFactory + buildOp = op + OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, + buildOpFactory, buildOp) + } + } catch { + case e: Exception => { + throw new RuntimeException("hash build failed. errmsg:" + e.getMessage()) + } + } finally { + OmniHashBuilderWithExprOperatorFactory.gLock.unlock() + } + } else { + val (opFactory, op) = createBuildOpFactoryAndOp() + buildOpFactory = opFactory + buildOp = op } - val startBuildGetOp = System.nanoTime() - buildOp.getOutput - buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) val startLookupCodegen = System.nanoTime() - val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, - probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, buildOpFactory, + probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, buildOpFactory, filter, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp = lookupOpFactory.createOperator() @@ -333,23 +395,31 @@ case class ColumnarBroadcastHashJoinExec( SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { lookupOp.close() lookupOpFactory.close() + if (enableShareBuildOp && canShareBuildOp) { + OmniHashBuilderWithExprOperatorFactory.gLock.lock() + OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id) + OmniHashBuilderWithExprOperatorFactory.gLock.unlock() + } else { + buildOp.close() + buildOpFactory.close() + } }) + val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) + val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema val reverse = buildSide == BuildLeft var left = 0 - var leftLen = streamedPlan.output.size - var right = streamedPlan.output.size + var leftLen = streamedPlanOutput.size + var right = streamedPlanOutput.size var rightLen = output.size if (reverse) { - left = streamedPlan.output.size + left = streamedPlanOutput.size leftLen = output.size right = 0 - rightLen = streamedPlan.output.size + rightLen = streamedPlanOutput.size } - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val enableJoinBatchMerge: Boolean = columnarConf.enableJoinBatchMerge val iterBatch = new Iterator[ColumnarBatch] { private var results: java.util.Iterator[VecBatch] = _ var res: Boolean = true @@ -392,27 +462,33 @@ case class ColumnarBroadcastHashJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) - var index = 0 - for (i <- left until leftLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 - } - for (i <- right until rightLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, 
resultVecs, vecs) + } else { + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } } - numOutputRows += result.getRowCount + val rowCnt: Int = result.getRowCount + numOutputRows += rowCnt numOutputVecBatchs += 1 - new ColumnarBatch(vecs.toArray, result.getRowCount) + result.close() + new ColumnarBatch(vecs.toArray, rowCnt) } } if (enableJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs) + new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs, bypassVecBatchs) } else { iterBatch } @@ -428,7 +504,7 @@ case class ColumnarBroadcastHashJoinExec( } private def multipleOutputForOneInput: Boolean = joinType match { - case _: InnerLike | LeftOuter | RightOuter => + case Inner | LeftOuter | RightOuter => // For inner and outer joins, one row from the streamed side may produce multiple result rows, // if the build side has duplicated keys. Note that here we wait for the broadcast to be // finished, which is a no-op because it's already finished when we wait it in `doProduce`. @@ -458,4 +534,27 @@ case class ColumnarBroadcastHashJoinExec( protected override def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { throw new UnsupportedOperationException(s"This operator doesn't support codegenAnti().") } -} \ No newline at end of file + + override def output: Seq[Attribute] = { + if (projectList.nonEmpty) { + projectList.map(_.toAttribute) + } else { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") + } + } + } + + +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index 33fb61a79828d1c6c28956b83312344ebe8279b6..722c166ec646b5caef18539afca4517d6fd34f95 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -19,25 +19,23 @@ package org.apache.spark.sql.execution.joins import java.util.Optional import java.util.concurrent.TimeUnit.NANOSECONDS - import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} -import nova.hetu.omniruntime.operator.join._ +import 
nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory, OmniLookupOuterJoinWithExprOperatorFactory} import nova.hetu.omniruntime.vector.VecBatch - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -50,9 +48,28 @@ case class ColumnarShuffledHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) + right: SparkPlan, + projectList: Seq[NamedExpression] = Seq.empty) extends HashJoin with ShuffledJoin { + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))} + |${ExplainUtils.generateFieldString("condition", joinCondStr)} + |${ExplainUtils.generateFieldString("projectList", projectList.map(_.toAttribute) ++ projectList.map(_.toAttribute).map(_.dataType))} + |${ExplainUtils.generateFieldString("output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, @@ -77,7 +94,13 @@ case class ColumnarShuffledHashJoinExec( override def nodeName: String = "OmniColumnarShuffledHashJoin" - override def output: Seq[Attribute] = super[ShuffledJoin].output + override def output: Seq[Attribute] = { + if (projectList.nonEmpty) { + projectList.map(_.toAttribute) + } else { + super[ShuffledJoin].output + } + } override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning @@ -92,7 +115,7 @@ case class ColumnarShuffledHashJoinExec( def buildCheck(): Unit = { joinType match { - case FullOuter | Inner => + case FullOuter | Inner | LeftAnti | LeftOuter | LeftSemi => case _ => throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -156,17 +179,32 @@ case class ColumnarShuffledHashJoinExec( buildOutput.zipWithIndex.foreach { case (att, i) => buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) } - val buildOutputCols = 
buildOutput.indices.toArray + + val buildOutputCols: Array[Int] = joinType match { + case Inner | FullOuter | LeftOuter => + getIndexArray(buildOutput, projectList) + case LeftExistence(_) => + Array[Int]() + case x => + throw new UnsupportedOperationException(s"ColumnShuffledHashJoin Join-type[$x] is not supported!") + } + val buildJoinColsExp = buildKeys.map { x => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) }.toArray + val prunedBuildOutput = pruneOutput(buildOutput, projectList) + val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + prunedBuildOutput.zipWithIndex.foreach { case (att, i) => + buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + val probeTypes = new Array[DataType](streamedOutput.size) streamedOutput.zipWithIndex.foreach { case (attr, i) => probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } - val probeOutputCols = streamedOutput.indices.toArray + val probeOutputCols = getIndexArray(streamedOutput, projectList) val probeHashColsExp = streamedKeys.map { x => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) @@ -183,13 +221,24 @@ case class ColumnarShuffledHashJoinExec( } val startBuildCodegen = System.nanoTime() val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(buildTypes, - buildJoinColsExp, filter, 1, new OperatorConfig(SpillConfig.NONE, + buildJoinColsExp, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp = buildOpFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + val startLookupCodegen = System.nanoTime() + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, + probeOutputCols, probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, + buildOpFactory, filter, new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val lookupOp = lookupOpFactory.createOperator() + lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + lookupOp.close() buildOp.close() + lookupOpFactory.close() buildOpFactory.close() }) @@ -210,32 +259,19 @@ case class ColumnarShuffledHashJoinExec( buildOp.getOutput buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) - val startLookupCodegen = System.nanoTime() - val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) - val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, - probeOutputCols, probeHashColsExp, buildOutputCols, buildTypes, lookupJoinType, - buildOpFactory, new OperatorConfig(SpillConfig.NONE, - new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - - val lookupOp = lookupOpFactory.createOperator() - lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) - - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { - lookupOp.close() - lookupOpFactory.close() - }) - + val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) + val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema val 
reverse = buildSide == BuildLeft var left = 0 - var leftLen = streamedPlan.output.size - var right = streamedPlan.output.size + var leftLen = streamedPlanOutput.size + var right = streamedPlanOutput.size var rightLen = output.size if (reverse) { - left = streamedPlan.output.size + left = streamedPlanOutput.size leftLen = output.size right = 0 - rightLen = streamedPlan.output.size + rightLen = streamedPlanOutput.size } val joinIter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { @@ -278,28 +314,34 @@ case class ColumnarShuffledHashJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) - var index = 0 - for (i <- left until leftLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 - } - for (i <- right until rightLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } } - numOutputRows += result.getRowCount + val rowCnt: Int = result.getRowCount + numOutputRows += rowCnt numOutputVecBatchs += 1 - new ColumnarBatch(vecs.toArray, result.getRowCount) + result.close() + new ColumnarBatch(vecs.toArray, rowCnt) } } if ("FULL OUTER" == joinType.sql) { val lookupOuterOpFactory = new OmniLookupOuterJoinWithExprOperatorFactory(probeTypes, probeOutputCols, - probeHashColsExp, buildOutputCols, buildTypes, buildOpFactory, + probeHashColsExp, buildOutputCols, buildOutputTypes, buildOpFactory, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) @@ -325,18 +367,22 @@ case class ColumnarShuffledHashJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) - var index = 0 - for (i <- left until leftLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 - } - for (i <- right until rightLen) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - index += 1 + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + var index = 0 + for (i <- left until leftLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } + for (i <- right until rightLen) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + index += 1 + } } numOutputRows += result.getRowCount numOutputVecBatchs += 1 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index b538a8613549b4986b968e2973434e79a6b4d38c..4a91517d5abf7122b7bcafcc4bfc3036a292e875 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -25,9 +25,8 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import 
com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType -import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniSmjBufferedTableWithExprOperatorFactory, OmniSmjStreamedTableWithExprOperatorFactory} import nova.hetu.omniruntime.vector.{BooleanVec, Decimal128Vec, DoubleVec, IntVec, LongVec, VarcharVec, Vec, VecBatch, ShortVec} @@ -35,7 +34,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch @@ -43,105 +41,22 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * Performs a sort merge join of two child relations. */ -class ColumnarSortMergeJoinExec( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan, - isSkewJoin: Boolean = false) - extends SortMergeJoinExec( +case class ColumnarSortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean) with CodegenSupport { - - override def supportsColumnar: Boolean = true - - override def supportCodegen: Boolean = false - - override def nodeName: String = { - if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin" - } - - val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 - val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 - val SMJ_NO_RESULT = 4 - val SMJ_FETCH_JOIN_DATA = 5 - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "streamedAddInputTime" -> - SQLMetrics.createMetric(sparkContext, "time in omni streamed addInput"), - "streamedCodegenTime" -> - SQLMetrics.createMetric(sparkContext, "time in omni streamed codegen"), - "bufferedAddInputTime" -> - SQLMetrics.createMetric(sparkContext, "time in omni buffered addInput"), - "bufferedCodegenTime" -> - SQLMetrics.createMetric(sparkContext, "time in omni buffered codegen"), - "getOutputTime" -> - SQLMetrics.createMetric(sparkContext, "time in omni buffered getOutput"), - "numOutputVecBatchs" -> - SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), - "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), - "numStreamVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of streamed vecBatchs"), - "numBufferVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatchs") - ) - - def buildCheck(): Unit = { - joinType match { - case _: InnerLike | LeftOuter | FullOuter => - // SMJ join support InnerLike | LeftOuter | FullOuter - case _ => - throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + - s"in ${this.nodeName}") - } - - val 
streamedTypes = new Array[DataType](left.output.size) - left.output.zipWithIndex.foreach { case (attr, i) => - streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) - } - val streamedKeyColsExp: Array[AnyRef] = leftKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) - }.toArray - - val bufferedTypes = new Array[DataType](right.output.size) - right.output.zipWithIndex.foreach { case (attr, i) => - bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) - } - val bufferedKeyColsExp: Array[AnyRef] = rightKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) - }.toArray - - if (!isSimpleColumnForAll(streamedKeyColsExp.map(expr => expr.toString))) { - checkOmniJsonWhiteList("", streamedKeyColsExp) - } - - if (!isSimpleColumnForAll(bufferedKeyColsExp.map(expr => expr.toString))) { - checkOmniJsonWhiteList("", bufferedKeyColsExp) - } - - condition match { - case Some(expr) => - val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, - OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) - if (!isSimpleColumn(filterExpr)) { - checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0)) - } - case _ => null - } - } + isSkewJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends BaseColumnarSortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkewJoin, projectList) { override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val numOutputVecBatchs = longMetric("numOutputVecBatchs") val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val bypassVecBatchs = longMetric("bypassVecBatchs") val streamedAddInputTime = longMetric("streamedAddInputTime") val streamedCodegenTime = longMetric("streamedCodegenTime") val bufferedAddInputTime = longMetric("bufferedAddInputTime") @@ -150,15 +65,6 @@ class ColumnarSortMergeJoinExec( val streamVecBatchs = longMetric("numStreamVecBatchs") val bufferVecBatchs = longMetric("numBufferVecBatchs") - val omniJoinType : nova.hetu.omniruntime.constants.JoinType = joinType match { - case _: InnerLike => OMNI_JOIN_TYPE_INNER - case LeftOuter => OMNI_JOIN_TYPE_LEFT - case FullOuter => OMNI_JOIN_TYPE_FULL - case x => - throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported " + - s"in ${this.nodeName}") - } - val streamedTypes = new Array[DataType](left.output.size) left.output.zipWithIndex.foreach { case (attr, i) => streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) @@ -167,7 +73,7 @@ class ColumnarSortMergeJoinExec( OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) }.toArray - val streamedOutputChannel = left.output.indices.toArray + val streamedOutputChannel = getIndexArray(left.output, projectList) val bufferedTypes = new Array[DataType](right.output.size) right.output.zipWithIndex.foreach { case (attr, i) => @@ -177,12 +83,19 @@ class ColumnarSortMergeJoinExec( OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) }.toArray - val bufferedOutputChannel = right.output.indices.toArray + val bufferedOutputChannel: Array[Int] = joinType match { + case Inner 
| LeftOuter | FullOuter => + getIndexArray(right.output, projectList) + case LeftExistence(_) => + Array[Int]() + case x => + throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported!") + } val filterString: String = condition match { case Some(expr) => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, - OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) + OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) case _ => null } @@ -214,14 +127,17 @@ class ColumnarSortMergeJoinExec( streamedOpFactory.close() }) + val prunedStreamOutput = pruneOutput(left.output, projectList) + val prunedBufferOutput = pruneOutput(right.output, projectList) + val prunedOutput = prunedStreamOutput ++ prunedBufferOutput val resultSchema = this.schema val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge val iterBatch = new Iterator[ColumnarBatch] { var isFinished : Boolean = joinType match { - case _: InnerLike => !streamedIter.hasNext || !bufferedIter.hasNext - case LeftOuter => !streamedIter.hasNext + case Inner | LeftSemi => !streamedIter.hasNext || !bufferedIter.hasNext + case LeftOuter | LeftAnti => !streamedIter.hasNext case FullOuter => !(streamedIter.hasNext || bufferedIter.hasNext) case x => throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported!") @@ -230,16 +146,30 @@ class ColumnarSortMergeJoinExec( var isStreamedFinished = false var isBufferedFinished = false var results: java.util.Iterator[VecBatch] = null + var flowControlCode: Int = SMJ_NEED_ADD_STREAM_TBL_DATA + var resCode: Int = RES_INIT def checkAndClose() : Unit = { - while (streamedIter.hasNext) { - streamVecBatchs += 1 - streamedIter.next().close() - } - while(bufferedIter.hasNext) { - bufferVecBatchs += 1 - bufferedIter.next().close() - } + while (streamedIter.hasNext) { + streamVecBatchs += 1 + streamedIter.next().close() + } + while(bufferedIter.hasNext) { + bufferVecBatchs += 1 + bufferedIter.next().close() + } + } + + // FLOW_CONTROL_CODE has 3 values: 2,3,4 + // 2-> add streamTable data + // 3-> add buffedTable data + // 4-> streamTable and buffedTable scan is finished + // RES_CODE has 2 values: 0,5 + // 0-> init status code, it means no result to fetch + // 5-> operator produced result data, we should fetch data + def decodeOpStatus(code: Int): Unit = { + flowControlCode = code >> 16 + resCode = code & 0xFFFF } override def hasNext: Boolean = { @@ -250,21 +180,20 @@ class ColumnarSortMergeJoinExec( if (results != null && results.hasNext) { return true } - // reset results and find next results + // reset results and RES_CODE results = null - // Add streamed data first - var inputReturnCode = SMJ_NEED_ADD_STREAM_TBL_DATA - while (inputReturnCode == SMJ_NEED_ADD_STREAM_TBL_DATA - || inputReturnCode == SMJ_NEED_ADD_BUFFERED_TBL_DATA) { - if (inputReturnCode == SMJ_NEED_ADD_STREAM_TBL_DATA) { + resCode = RES_INIT + // add data until operator produce results or scan is finished + while (resCode == RES_INIT && flowControlCode != SCAN_FINISH){ + if (flowControlCode == SMJ_NEED_ADD_STREAM_TBL_DATA) { val startBuildStreamedInput = System.nanoTime() if (!isStreamedFinished && streamedIter.hasNext) { val batch = streamedIter.next() streamVecBatchs += 1 val inputVecBatch = transColBatchToVecBatch(batch) - inputReturnCode = streamedOp.addInput(inputVecBatch) + 
decodeOpStatus(streamedOp.addInput(inputVecBatch)) } else { - inputReturnCode = streamedOp.addInput(createEofVecBatch(streamedTypes)) + decodeOpStatus(streamedOp.addInput(createEofVecBatch(streamedTypes))) isStreamedFinished = true } streamedAddInputTime += @@ -275,38 +204,31 @@ class ColumnarSortMergeJoinExec( val batch = bufferedIter.next() bufferVecBatchs += 1 val inputVecBatch = transColBatchToVecBatch(batch) - inputReturnCode = bufferedOp.addInput(inputVecBatch) + decodeOpStatus(bufferedOp.addInput(inputVecBatch)) } else { - inputReturnCode = bufferedOp.addInput(createEofVecBatch(bufferedTypes)) + decodeOpStatus(bufferedOp.addInput(createEofVecBatch(bufferedTypes))) isBufferedFinished = true } bufferedAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildBufferedInput) } } - if (inputReturnCode == SMJ_FETCH_JOIN_DATA) { + if (resCode == SMJ_FETCH_JOIN_DATA) { val startGetOutputTime = System.nanoTime() results = bufferedOp.getOutput val hasNext = results.hasNext getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOutputTime) - if (hasNext) { - return true - } else { - isFinished = true - results = null - checkAndClose() - return false - } + return hasNext } - if (inputReturnCode == SMJ_NO_RESULT) { + if (flowControlCode == SCAN_FINISH) { isFinished = true results = null checkAndClose() return false } - throw new UnsupportedOperationException(s"Unknown return code ${inputReturnCode}") + throw new UnsupportedOperationException(s"Unknown return code ${flowControlCode},${resCode} ") } override def next(): ColumnarBatch = { @@ -315,10 +237,14 @@ class ColumnarSortMergeJoinExec( getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOutputTime) val resultVecs = result.getVectors val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false) - for (index <- output.indices) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(index)) + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + for (index <- output.indices) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(index)) + } } numOutputVecBatchs += 1 numOutputRows += result.getRowCount @@ -339,7 +265,7 @@ class ColumnarSortMergeJoinExec( case DataType.DataTypeId.OMNI_BOOLEAN => new BooleanVec(0) case DataType.DataTypeId.OMNI_CHAR | DataType.DataTypeId.OMNI_VARCHAR => - new VarcharVec(0, 0) + new VarcharVec(0) case DataType.DataTypeId.OMNI_DECIMAL128 => new Decimal128Vec(0) case DataType.DataTypeId.OMNI_SHORT => @@ -359,7 +285,7 @@ class ColumnarSortMergeJoinExec( } if (enableSortMergeJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs) + new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs, bypassVecBatchs) } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinFusionExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinFusionExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..7f7d90be90a72c475c309d1c9299d2c0bdd168de --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinFusionExec.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
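The reworked sort-merge-join driver above replaces the old single return code (SMJ_NEED_ADD_STREAM_TBL_DATA / SMJ_NEED_ADD_BUFFERED_TBL_DATA / SMJ_NO_RESULT / SMJ_FETCH_JOIN_DATA) with a packed status word: the high 16 bits carry the flow-control code and the low 16 bits carry the result code, exactly as decodeOpStatus splits them. A minimal sketch of that decoding, using the constant values quoted in the diff's own comments (2/3/4 for flow control, 0/5 for results); the names are placeholders, since the real constants live in BaseColumnarSortMergeJoinExec, which is not part of this hunk:

    // Sketch only: constants mirror the values listed in the patch comments.
    object SmjStatusSketch {
      val SMJ_NEED_ADD_STREAM_TBL_DATA = 2   // flow control: feed more streamed-side batches
      val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 // flow control: feed more buffered-side batches
      val SCAN_FINISH = 4                    // flow control: both sides fully consumed
      val RES_INIT = 0                       // result code: nothing to fetch yet
      val SMJ_FETCH_JOIN_DATA = 5            // result code: join output is ready

      // addInput returns one packed int; high 16 bits = flow control, low 16 bits = result code
      def decode(code: Int): (Int, Int) = (code >> 16, code & 0xFFFF)
    }

    // Example: "need more streamed input, and results are ready to fetch"
    // SmjStatusSketch.decode((2 << 16) | 5) == (2, 5)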
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.concurrent.TimeUnit.NANOSECONDS +import java.util.Optional + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.util.OmniAdaptorUtil +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.join.{OmniSmjBufferedTableWithExprOperatorFactoryV3, OmniSmjStreamedTableWithExprOperatorFactoryV3} +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Performs a sort merge join of two child relations. 
+ */ +case class ColumnarSortMergeJoinFusionExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isSkewJoin: Boolean = false, + projectList: Seq[NamedExpression] = Seq.empty) + extends BaseColumnarSortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkewJoin, projectList) { + + override def nodeName: String = { + if (isSkewJoin) "OmniColumnarSortMergeJoinFusion(skew=true)" else "OmniColumnarSortMergeJoinFusion" + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val bypassVecBatchs = longMetric("bypassVecBatchs") + val streamedAddInputTime = longMetric("streamedAddInputTime") + val streamedCodegenTime = longMetric("streamedCodegenTime") + val bufferedAddInputTime = longMetric("bufferedAddInputTime") + val bufferedCodegenTime = longMetric("bufferedCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val streamVecBatchs = longMetric("numStreamVecBatchs") + val bufferVecBatchs = longMetric("numBufferVecBatchs") + + val streamedTypes = new Array[DataType](left.output.size) + left.output.zipWithIndex.foreach { case (attr, i) => + streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val streamedKeyColsExp = leftKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + }.toArray + val streamedOutputChannel = getIndexArray(left.output, projectList) + + val bufferedTypes = new Array[DataType](right.output.size) + right.output.zipWithIndex.foreach { case (attr, i) => + bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val bufferedKeyColsExp = rightKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + }.toArray + val bufferedOutputChannel: Array[Int] = joinType match { + case Inner | LeftOuter | FullOuter => + getIndexArray(right.output, projectList) + case LeftExistence(_) => + Array[Int]() + case x => + throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported!") + } + + val filterString: String = condition match { + case Some(expr) => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) + case _ => null + } + + left.executeColumnar().zipPartitions(right.executeColumnar()) { (streamedIter, bufferedIter) => + val filter: Optional[String] = Optional.ofNullable(filterString) + val startStreamedCodegen = System.nanoTime() + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val streamedOpFactory = new OmniSmjStreamedTableWithExprOperatorFactoryV3(streamedTypes, + streamedKeyColsExp, streamedOutputChannel, lookupJoinType, filter, + new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + + val streamedOp = streamedOpFactory.createOperator + streamedCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startStreamedCodegen) + + val startBufferedCodegen = System.nanoTime() + val bufferedOpFactory = new OmniSmjBufferedTableWithExprOperatorFactoryV3(bufferedTypes, + bufferedKeyColsExp, 
bufferedOutputChannel, streamedOpFactory, + new OperatorConfig(SpillConfig.NONE, + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val bufferedOp = bufferedOpFactory.createOperator + bufferedCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBufferedCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + bufferedOp.close() + streamedOp.close() + bufferedOpFactory.close() + streamedOpFactory.close() + }) + + while (bufferedIter.hasNext) { + val cb = bufferedIter.next() + val vecs = transColBatchToOmniVecs(cb, false) + val startBuildInput = System.nanoTime() + bufferedOp.addInput(new VecBatch(vecs, cb.numRows())) + bufferedAddInputTime += NANOSECONDS.toMillis(System.nanoTime() -startBuildInput) + } + + while (streamedIter.hasNext) { + val cb = streamedIter.next() + val vecs = transColBatchToOmniVecs(cb, false) + val startBuildInput = System.nanoTime() + streamedOp.addInput(new VecBatch(vecs, cb.numRows())) + streamedAddInputTime += NANOSECONDS.toMillis(System.nanoTime() -startBuildInput) + } + + val prunedStreamOutput = pruneOutput(left.output, projectList) + val prunedBufferOutput = pruneOutput(right.output, projectList) + val prunedOutput = prunedStreamOutput ++ prunedBufferOutput + val resultSchema = this.schema + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge + + val startGetOutputTime: Long = System.nanoTime() + val results: java.util.Iterator[VecBatch] = bufferedOp.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() -startGetOutputTime) + + val iterBatch = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() -startGetOp) + hasNext + } + + override def next(): ColumnarBatch = { + val startGetOp: Long = System.nanoTime() + val result: VecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + + val resultVecs =result.getVectors + val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false) + if (projectList.nonEmpty) { + reorderVecs(prunedOutput, projectList, resultVecs, vecs) + } else { + for (index <- output.indices) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(index)) + } + } + numOutputVecBatchs += 1 + numOutputRows += result.getRowCount + result.close() + new ColumnarBatch(vecs.toArray, result.getRowCount) + } + } + + if (enableSortMergeJoinBatchMerge) { + new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs, bypassVecBatchs) + } else { + iterBatch + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index c67d45032589b74ee414625010cba01ba716465b..6236aefee7c8836c8a244e34e004da8f641f6e84 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -31,13 +31,14 @@ import org.apache.spark.sql.types.{BooleanType, DateType, DecimalType, DoubleTyp import org.apache.spark.sql.vectorized.ColumnarBatch class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: 
StructType, - numMergedVecBatchs: SQLMetric) extends Iterator[ColumnarBatch] { + numMergedVecBatchs: SQLMetric, bypassVecBatchs: SQLMetric) extends Iterator[ColumnarBatch] { private val outputQueue = new mutable.Queue[VecBatch] private val bufferedVecBatch = new ListBuffer[VecBatch]() val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf private val maxBatchSizeInBytes: Int = columnarConf.maxBatchSizeInBytes private val maxRowCount: Int = columnarConf.maxRowCount + private val mergedBatchThreshold: Int = columnarConf.mergedBatchThreshold private var totalRows = 0 private var currentBatchSizeInBytes = 0 @@ -57,8 +58,7 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, vecs(index) = new BooleanVec(columnSize) case StringType => val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) - vecs(index) = new VarcharVec(vecType.asInstanceOf[VarcharDataType].getWidth * columnSize, - columnSize) + vecs(index) = new VarcharVec(columnSize) case dt: DecimalType => if (DecimalType.is64BitDecimalType(dt)) { vecs(index) = new LongVec(columnSize) @@ -98,6 +98,8 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, src.close() } } + // close bufferedBatch + bufferedBatch.foreach(batch => batch.close()) } private def flush(): Unit = { @@ -132,7 +134,13 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, val batch: ColumnarBatch = iter.next() val input: Array[Vec] = transColBatchToOmniVecs(batch) val vecBatch = new VecBatch(input, batch.numRows()) - buffer(vecBatch) + if (vecBatch.getRowCount > mergedBatchThreshold) { + flush() + outputQueue.enqueue(vecBatch) + bypassVecBatchs += 1 + } else { + buffer(vecBatch) + } } if (outputQueue.isEmpty && bufferedVecBatch.isEmpty) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala index 6012da931bb3b93ef8a3e6690d42ba3d1e4949e0..946c90a9baf346dc4e47253ced50a53def22374b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.execution.util -import nova.hetu.omniruntime.vector.VecAllocator - +import nova.hetu.omniruntime.memory +import nova.hetu.omniruntime.memory.MemoryManager import org.apache.spark.{SparkEnv, TaskContext} object SparkMemoryUtils { private val max: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g") - VecAllocator.setRootAllocatorLimit(max) + MemoryManager.setGlobalMemoryLimit(max) def init(): Unit = {} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala new file mode 100644 index 0000000000000000000000000000000000000000..0503b2b7b684537f5191585cccf8b55cf50997d8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
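The MergeIterator change above adds a pass-through path: an incoming VecBatch whose row count exceeds mergedBatchThreshold is no longer copied into the merge buffer; instead the buffer is flushed first and the large batch is queued as-is, counted by the new bypassVecBatchs metric. A minimal standalone sketch of that routing decision, with plain stub types standing in for VecBatch and SQLMetric and the threshold passed in directly:

    // Sketch only: mirrors the buffer-vs-bypass decision in MergeIterator.
    final case class BatchStub(rowCount: Int)

    class MergeRouterSketch(mergedBatchThreshold: Int) {
      private val buffered = scala.collection.mutable.ListBuffer.empty[BatchStub]
      private val output = scala.collection.mutable.Queue.empty[BatchStub]
      var bypassCount = 0

      private def flush(): Unit = if (buffered.nonEmpty) {
        // the real code concatenates the buffered batches into one merged VecBatch here
        output.enqueue(BatchStub(buffered.map(_.rowCount).sum))
        buffered.clear()
      }

      def add(batch: BatchStub): Unit =
        if (batch.rowCount > mergedBatchThreshold) {
          flush()                 // keep ordering: emit what was buffered first
          output.enqueue(batch)   // then pass the large batch through untouched
          bypassCount += 1
        } else {
          buffered += batch       // small batches keep accumulating for a later merge
        }
    }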
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.common.StatsSetupConst + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy + +/** + * Prune hive table partitions using partition filters on [[HiveTableRelation]]. The pruned + * partitions will be kept in [[HiveTableRelation.prunedPartitions]], and the statistics of + * the hive table relation will be updated based on pruned partitions. + * + * This rule is executed in optimization phase, so the statistics can be updated before physical + * planning, which is useful for some spark strategy, e.g. + * [[org.apache.spark.sql.execution.SparkStrategies.JoinSelection]]. + * + * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source. + */ +private[sql] class PruneHiveTablePartitions(session: SparkSession) + extends Rule[LogicalPlan] with CastSupport with PredicateHelper { + + /** + * Extract the partition filters from the filters on the table. + */ + private def getPartitionKeyFilters( + filters: Seq[Expression], + relation: HiveTableRelation): ExpressionSet = { + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output) + val partitionColumnSet = AttributeSet(relation.partitionCols) + ExpressionSet( + normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet))) + } + + /** + * Prune the hive table using filters on the partitions of the table. + */ + private def prunePartitions( + relation: HiveTableRelation, + partitionFilters: ExpressionSet): Seq[CatalogTablePartition] = { + if (conf.metastorePartitionPruning) { + session.sessionState.catalog.listPartitionsByFilter( + relation.tableMeta.identifier, partitionFilters.toSeq) + } else { + ExternalCatalogUtils.prunePartitionsByFilter(relation.tableMeta, + session.sessionState.catalog.listPartitions(relation.tableMeta.identifier), + partitionFilters.toSeq, conf.sessionLocalTimeZone) + } + } + + /** + * Update the statistics of the table. 
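getPartitionKeyFilters above keeps only deterministic, subquery-free predicates and then retains the conjuncts that reference partition columns exclusively, so that just those can be pushed to the metastore. A reduced sketch of that selection; it skips the DataSourceStrategy.normalizeExprs step and assumes the PredicateHelper mixin that provides extractPredicatesWithinOutputSet:

    // Sketch only: simplified version of the partition-filter extraction above.
    import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}

    trait PartitionFilterSketch extends PredicateHelper {
      def partitionKeyFilters(filters: Seq[Expression], partitionCols: AttributeSet): ExpressionSet =
        ExpressionSet(
          filters
            .filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)) // no nondeterminism, no subqueries
            .flatMap(extractPredicatesWithinOutputSet(_, partitionCols)))       // conjuncts over partition columns only
    }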
+ */ + private def updateTableMeta( + relation: HiveTableRelation, + prunedPartitions: Seq[CatalogTablePartition], + partitionKeyFilters: ExpressionSet): CatalogTable = { + val sizeOfPartitions = prunedPartitions.map { partition => + val rawDataSize = partition.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) + val totalSize = partition.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + if (rawDataSize.isDefined && rawDataSize.get > 0) { + rawDataSize.get + } else if (totalSize.isDefined && totalSize.get > 0L) { + totalSize.get + } else { + 0L + } + } + // Fix spark issue SPARK-34119(row 95-106) + if (sizeOfPartitions.forall(_ > 0)) { + val filteredStats = + FilterEstimation(Filter(partitionKeyFilters.reduce(And), relation)).estimate + val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) => + (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType)) + }) + relation.tableMeta.copy( + stats = Some(CatalogStatistics( + sizeInBytes = BigInt(sizeOfPartitions.sum), + rowCount = filteredStats.flatMap(_.rowCount), + colStats = colStats.getOrElse(Map.empty)))) + } else { + relation.tableMeta + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) + if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => + val partitionKeyFilters = getPartitionKeyFilters(filters, relation) + if (partitionKeyFilters.nonEmpty) { + val newPartitions = prunePartitions(relation, partitionKeyFilters) + // Fix spark issue SPARK-34119(row 117) + val newTableMeta = updateTableMeta(relation, newPartitions, partitionKeyFilters) + val newRelation = relation.copy( + tableMeta = newTableMeta, prunedPartitions = Some(newPartitions)) + // Keep partition filters so that they are visible in physical planning + Project(projections, Filter(filters.reduceLeft(And), newRelation)) + } else { + op + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala index cc3763164f211cee9083395d2335f65c2a286c91..bb31d7f82bd171d447502819309b045cd1a90db5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala @@ -15,28 +15,31 @@ * limitations under the License. 
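updateTableMeta above sizes each pruned partition from its Hive statistics, preferring rawDataSize, falling back to totalSize, and otherwise contributing 0, and it only rewrites the catalog statistics when every partition reported a positive size (the SPARK-34119 guard). A one-line sketch of the per-partition choice, with the two statistics passed in as plain options rather than read from partition.parameters:

    // Sketch only: per-partition size selection used by updateTableMeta above.
    def partitionSizeInBytes(rawDataSize: Option[Long], totalSize: Option[Long]): Long =
      rawDataSize.filter(_ > 0L).orElse(totalSize.filter(_ > 0L)).getOrElse(0L)

    // e.g. partitionSizeInBytes(Some(-1L), Some(4096L)) == 4096L
    //      partitionSizeInBytes(None, None)             == 0L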
*/ - package org.apache.spark.sql.types +package org.apache.spark.sql.types import org.apache.spark.sql.execution.FileSourceScanExec - import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat - import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.internal.SQLConf - object ColumnarBatchSupportUtil { - def checkColumnarBatchSupport(conf: SQLConf, plan: FileSourceScanExec): Boolean = { - val isSupportFormat: Boolean = { - plan.relation.fileFormat match { - case _: OrcFileFormat => - conf.orcVectorizedReaderEnabled - case _ => - false - } - } - val supportBatchReader: Boolean = { - val partitionSchema = plan.relation.partitionSchema - val resultSchema = StructType(plan.requiredSchema.fields ++ partitionSchema.fields) - conf.orcVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - } - supportBatchReader && isSupportFormat - } - } +object ColumnarBatchSupportUtil { + def checkColumnarBatchSupport(conf: SQLConf, plan: FileSourceScanExec): Boolean = { + val isSupportFormat: Boolean = { + plan.relation.fileFormat match { + case _: OrcFileFormat => + conf.orcVectorizedReaderEnabled + case _: ParquetFileFormat => + conf.parquetVectorizedReaderEnabled + case _ => + false + } + } + val supportBatchReader: Boolean = { + val partitionSchema = plan.relation.partitionSchema + val resultSchema = StructType(plan.requiredSchema.fields ++ partitionSchema.fields) + (conf.orcVectorizedReaderEnabled || conf.parquetVectorizedReaderEnabled) && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + } + supportBatchReader && isSupportFormat + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java index 74fccca66fad64dac9c96ae5f60591de40e92012..8be5702dfbabc5bc847e4ebe547d1d4dfa243e6f 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleTest.java @@ -141,7 +141,7 @@ abstract class ColumnShuffleTest { } case OMNI_VARCHAR: case OMNI_CHAR: { - tmpVec = new VarcharVec(rowNum * 16, rowNum); + tmpVec = new VarcharVec(rowNum); for (int j = 0; j < rowNum; j++) { ((VarcharVec)tmpVec).set(j, ("VAR_" + (j + 1) + "_END").getBytes(StandardCharsets.UTF_8)); if (mixHalfNull && (j % 2) == 0) { @@ -196,7 +196,7 @@ abstract class ColumnShuffleTest { public List buildValChar(int pid, String varChar) { IntVec c0 = new IntVec(1); - VarcharVec c1 = new VarcharVec(8, 1); + VarcharVec c1 = new VarcharVec(1); c0.set(0, pid); c1.set(0, varChar.getBytes(StandardCharsets.UTF_8)); List columns = new ArrayList<>(); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java new file mode 100644 index 0000000000000000000000000000000000000000..5996413555c00cd1dedc2fe81bf50da5efe3c097 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. 
All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import junit.framework.TestCase; +import nova.hetu.omniruntime.vector.*; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING) +public class ParquetColumnarBatchJniReaderTest extends TestCase { + private ParquetColumnarBatchJniReader parquetColumnarBatchJniReader; + + private Vec[] vecs; + + @Before + public void setUp() throws Exception { + parquetColumnarBatchJniReader = new ParquetColumnarBatchJniReader(); + + List<Integer> rowGroupIndices = new ArrayList<>(); + rowGroupIndices.add(0); + List<Integer> columnIndices = new ArrayList<>(); + Collections.addAll(columnIndices, 0, 1, 3, 6, 7, 8, 9, 10, 12); + File file = new File("../cpp/test/tablescan/resources/parquet_data_all_type"); + String path = file.getAbsolutePath(); + parquetColumnarBatchJniReader.initializeReaderJava(path, 100000, rowGroupIndices, columnIndices, "root@sample"); + vecs = new Vec[9]; + } + + @After + public void tearDown() throws Exception { + parquetColumnarBatchJniReader.close(); + for (Vec vec : vecs) { + vec.close(); + } + } + + @Test + public void testRead() { + long num = parquetColumnarBatchJniReader.next(vecs); + assertTrue(num == 1); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/date_dim.parquet b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/date_dim.parquet new file mode 100644 index 0000000000000000000000000000000000000000..a41dc76ea1b824b9ba30245a2d5b0069ff756294 Binary files /dev/null and b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/date_dim.parquet differ diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala index d60c544d765abb48c58d2758319f53fbfe8a8e3b..a4131e3ef869c301b56a72b48f3d2994884241f3 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala @@ -248,6 +248,10 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"abs\"," 
+ " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}]}", Abs(allAttribute(0))) + + checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"round\"," + + " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0},{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":2}]}", + Round(allAttribute(0), Literal(2))) } protected def checkExpressionRewrite(expected: Any, expression: Expression): Unit = { @@ -272,45 +276,13 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val cnAttribute = Seq(AttributeReference("char_1", StringType)(), AttributeReference("char_20", StringType)(), AttributeReference("varchar_1", StringType)(), AttributeReference("varchar_20", StringType)()) - val like = Like(cnAttribute(2), Literal("我_"), '\\'); - val likeResult = procLikeExpression(like, getExprIdMap(cnAttribute)) - val likeExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^我.$\",\"width\":4}]}" - if (!likeExp.equals(likeResult)) { - fail(s"expression($like) not match with expected value:$likeExp," + - s"running value:$likeResult") - } - - val startsWith = StartsWith(cnAttribute(2), Literal("我")); - val startsWithResult = procLikeExpression(startsWith, getExprIdMap(cnAttribute)) - val startsWithExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^我.*$\",\"width\":5}]}" - if (!startsWithExp.equals(startsWithResult)) { - fail(s"expression($startsWith) not match with expected value:$startsWithExp," + - s"running value:$startsWithResult") - } - - val endsWith = EndsWith(cnAttribute(2), Literal("我")); - val endsWithResult = procLikeExpression(endsWith, getExprIdMap(cnAttribute)) - val endsWithExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^.*我$\",\"width\":5}]}" - if (!endsWithExp.equals(endsWithResult)) { - fail(s"expression($endsWith) not match with expected value:$endsWithExp," + - s"running value:$endsWithResult") - } - - val contains = Contains(cnAttribute(2), Literal("我")); - val containsResult = procLikeExpression(contains, getExprIdMap(cnAttribute)) - val containsExp = "{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\", \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000}, {\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"^.*我.*$\",\"width\":7}]}" - if (!containsExp.equals(containsResult)) { - fail(s"expression($contains) not match with expected value:$containsExp," + - s"running value:$containsResult") - } - val t1 = new Tuple2(Not(EqualTo(cnAttribute(0), Literal("新"))), Not(EqualTo(cnAttribute(1), Literal("官方爸爸")))) val t2 = new Tuple2(Not(EqualTo(cnAttribute(2), Literal("爱你三千遍"))), Not(EqualTo(cnAttribute(2), Literal("新")))) val branch = Seq(t1, t2) val elseValue = Some(Not(EqualTo(cnAttribute(3), Literal("啊水水水水")))) val caseWhen = CaseWhen(branch, elseValue); val caseWhenResult = rewriteToOmniJsonExpressionLiteral(caseWhen, 
getExprIdMap(cnAttribute)) - val caseWhenExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" + val caseWhenExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" if (!caseWhenExp.equals(caseWhenResult)) { fail(s"expression($caseWhen) not match with expected value:$caseWhenExp," + s"running value:$caseWhenResult") @@ -318,7 +290,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val isNull = IsNull(cnAttribute(0)); val isNullResult = rewriteToOmniJsonExpressionLiteral(isNull, getExprIdMap(cnAttribute)) - val isNullExp = "{\"exprType\":\"IS_NULL\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000}]}" + val isNullExp = 
"{\"exprType\":\"IS_NULL\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}]}" if (!isNullExp.equals(isNullResult)) { fail(s"expression($isNull) not match with expected value:$isNullExp," + s"running value:$isNullResult") @@ -327,7 +299,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val children = Seq(cnAttribute(0), cnAttribute(1)) val coalesce = Coalesce(children); val coalesceResult = rewriteToOmniJsonExpressionLiteral(coalesce, getExprIdMap(cnAttribute)) - val coalesceExp = "{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":2000, \"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000}}" + val coalesceExp = "{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":50, \"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}}" if (!coalesceExp.equals(coalesceResult)) { fail(s"expression($coalesce) not match with expected value:$coalesceExp," + s"running value:$coalesceResult") @@ -356,7 +328,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val elseValue = Some(Not(EqualTo(caseWhenAttribute(3), Literal("啊水水水水")))) val expression = CaseWhen(branch, elseValue); val runResult = procCaseWhenExpression(expression, getExprIdMap(caseWhenAttribute)) - val filterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":2000},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" + val filterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, 
\"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" if (!filterExp.equals(runResult)) { fail(s"expression($expression) not match with expected value:$filterExp," + s"running value:$runResult") diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala index 00adf145979e33f7dd7b1c49873fd72cdff18756..998791c8c0b11499a6cd92ca3ffdffc1a7f9b0fc 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala @@ -328,7 +328,7 @@ object ColumnarShuffleWriterSuite { def initOmniColumnVarcharVector(values: Array[java.lang.String]): OmniColumnVector = { val length = values.length - val vecTmp = new VarcharVec(1024, length) + val vecTmp = new VarcharVec(length) (0 until length).foreach { i => if (values(i) != null) { vecTmp.set(i, values(i).getBytes()) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4ee27c44db0a1f09a0252869ae3e3277fd01cd2b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala @@ -0,0 +1,218 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class HeuristicJoinReorderSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown", FixedPoint(100), + CombineFilters, + PushPredicateThroughNonJoin, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + ColumnPruning, + RemoveNoopOperators, + CollapseProject) :: + Batch("Heuristic Join Reorder", FixedPoint(1), + DelayCartesianProduct, + HeuristicJoinReorder, + PushDownPredicates, + ColumnPruning, + CollapseProject, + RemoveNoopOperators) :: Nil + } + + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private val testRelation1 = LocalRelation('d.int) + + private val IOV_ALARM_DAILY = LocalRelation('DID.int, 'DATA_TIME.int) + private val DETAILS = LocalRelation('CODE.int) + private val IOV_BIZ_CAR_INFO_ALL2 = LocalRelation('DID.int, 'CBM_MAG_COMPANY_ID.string) + private val IOV_BIZ_CAN_BUS_TYPE = LocalRelation('CODE.int, 'SITE.int, 'ID.int) + private val CBM_COM_DDIC_CONTENT = LocalRelation('ID.int, 'CBM_COM_DDIC_TYPE_ID.int) + private val CBM_COM_DDIC_TYPE = LocalRelation('ID.int, 'CODE.string) + private val IOV_BIZ_L_OPTION_RANK_TYPE = + LocalRelation('IOV_BIZ_CAN_BUS_TYPE_ID.int, 'CBM_COM_OPTION_RANK_ID.int) + private val CBM_COM_OPTION_RANK = LocalRelation('ID.int, 'CBM_MAG_COMPANY_ID.int) + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val queryAnswers = Seq( + ( + x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) + ), + ( + x.join(y, Cross).join(z, Cross) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), + x.join(z, Cross, Some("x.b".attr === "z.b".attr)) + .join(y, Cross, Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) + ), + ( + x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr), + x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) + ) + ) + + queryAnswers foreach { queryAnswerPair => + val optimized = Optimize.execute(queryAnswerPair._1.analyze) + comparePlans(optimized, queryAnswerPair._2.analyze) + } + } + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, + expected: Option[(Seq[LogicalPlan], Seq[Expression])]): Unit = { + val expectedNoCross = expected map { + seq_pair => { + val plans = seq_pair._1 + val noCartesian = plans map { plan => (plan, Inner) } + (noCartesian, seq_pair._2) + } + } + 
testExtractCheckCross(plan, expectedNoCross) + } + + def testExtractCheckCross(plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], + Seq[Expression])]): Unit = { + assert( + ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2))) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some((Seq(x, y), Seq()))) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some((Seq(x, y), Seq("x.b".attr === "y.d".attr)))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some((Seq(x, y), Seq("x.b".attr === "y.d".attr)))) + testExtract(x.join(y).join(z), Some((Seq(x, y, z), Seq()))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some((Seq(x, y, z), Seq("x.b".attr === "y.d".attr)))) + testExtract(x.join(y).join(x.join(z)), Some((Seq(x, y, x.join(z)), Seq()))) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some((Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr)))) + + testExtractCheckCross(x.join(y, Cross), Some((Seq((x, Cross), (y, Cross)), Seq()))) + testExtractCheckCross(x.join(y, Cross).join(z, Cross), + Some((Seq((x, Cross), (y, Cross), (z, Cross)), Seq()))) + testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Cross), + Some((Seq((x, Cross), (y, Cross), (z, Cross)), Seq("x.b".attr === "y.d".attr)))) + testExtractCheckCross(x.join(y, Inner, Some("x.b".attr === "y.d".attr)).join(z, Cross), + Some((Seq((x, Inner), (y, Inner), (z, Cross)), Seq("x.b".attr === "y.d".attr)))) + testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Inner), + Some((Seq((x, Cross), (y, Cross), (z, Inner)), Seq("x.b".attr === "y.d".attr)))) + } + + test("DelayCartesianProduct: basic scenario") { + val T = IOV_ALARM_DAILY.subquery('T) + val DT = DETAILS.subquery('DT) + val C = IOV_BIZ_CAR_INFO_ALL2.subquery('C) + val CAT = IOV_BIZ_CAN_BUS_TYPE.subquery('CAT) + val DDIC = CBM_COM_DDIC_CONTENT.subquery('DDIC) + val DDICT = CBM_COM_DDIC_TYPE.subquery('DDICT) + val OPRL = IOV_BIZ_L_OPTION_RANK_TYPE.subquery('OPRL) + val OPR = CBM_COM_OPTION_RANK.subquery('OPR) + + val query = T.join(DT, condition = None) + .join(C, condition = Some("C.DID".attr === "T.DID".attr)) + .join(CAT, condition = Some("CAT.CODE".attr === "DT.CODE".attr)) + .join(DDIC, condition = Some("DDIC.ID".attr === "CAT.SITE".attr)) + .join(DDICT, condition = Some("DDICT.ID".attr === "DDIC.CBM_COM_DDIC_TYPE_ID".attr)) + .join(OPRL, condition = Some("OPRL.IOV_BIZ_CAN_BUS_TYPE_ID".attr === "CAT.ID".attr)) + .join(OPR, condition = Some("OPR.ID".attr === "OPRL.CBM_COM_OPTION_RANK_ID".attr)) + .where(("T.DATA_TIME".attr < 100) + && ("C.CBM_MAG_COMPANY_ID".attr like "%500%") + && ("OPR.CBM_MAG_COMPANY_ID".attr === -1) + && ("DDICT.CODE".attr === "2004")) + val optimized = Optimize.execute(query.analyze) + + val clique1 = T.where("T.DATA_TIME".attr < 100) + .join(C.where("C.CBM_MAG_COMPANY_ID".attr like "%500%"), + condition = Some("C.DID".attr === "T.DID".attr)) + val clique2 = DT.join(CAT, condition = Some("CAT.CODE".attr === "DT.CODE".attr)) + .join(DDIC, condition = Some("DDIC.ID".attr === "CAT.SITE".attr)) + .join(DDICT.where("DDICT.CODE".attr === "2004"), + condition = Some("DDICT.ID".attr === "DDIC.CBM_COM_DDIC_TYPE_ID".attr)) + .join(OPRL, condition = Some("OPRL.IOV_BIZ_CAN_BUS_TYPE_ID".attr === "CAT.ID".attr)) + .join(OPR.where("OPR.CBM_MAG_COMPANY_ID".attr === -1), + condition = Some("OPR.ID".attr === "OPRL.CBM_COM_OPTION_RANK_ID".attr)) + val 
expected = clique1.join(clique2, condition = None) + .select(Seq("T.DID", "T.DATA_TIME", "DT.CODE", "C.DID", "C.CBM_MAG_COMPANY_ID", "CAT.CODE", + "CAT.SITE", "CAT.ID", "DDIC.ID", "DDIC.CBM_COM_DDIC_TYPE_ID", "DDICT.ID", "DDICT.CODE", + "OPRL.IOV_BIZ_CAN_BUS_TYPE_ID", "OPRL.CBM_COM_OPTION_RANK_ID", "OPR.ID", + "OPR.CBM_MAG_COMPANY_ID").map(_.attr): _*).analyze + + comparePlans(optimized, expected) + } + + test("DelayCartesianProduct: more than two cliques") { + val big1 = testRelation.subquery('big1) + val big2 = testRelation.subquery('big2) + val big3 = testRelation.subquery('big3) + val small1 = testRelation1.subquery('small1) + val small2 = testRelation1.subquery('small2) + val small3 = testRelation1.subquery('small3) + val small4 = testRelation1.subquery('small4) + + val query = big1.join(big2, condition = None) + .join(big3, condition = None) + .join(small1, condition = Some("big1.a".attr === "small1.d".attr)) + .join(small2, condition = Some("big2.b".attr === "small2.d".attr)) + .join(small3, condition = Some("big3.a".attr === "small3.d".attr)) + .join(small4, condition = Some("big3.b".attr === "small4.d".attr)) + val optimized = Optimize.execute(query.analyze) + + val clique1 = big1.join(small1, condition = Some("big1.a".attr === "small1.d".attr)) + val clique2 = big2.join(small2, condition = Some("big2.b".attr === "small2.d".attr)) + val clique3 = big3.join(small3, condition = Some("big3.a".attr === "small3.d".attr)) + .join(small4, condition = Some("big3.b".attr === "small4.d".attr)) + val expected = clique1.join(clique2, condition = None) + .join(clique3, condition = None) + .select(Seq("big1.a", "big1.b", "big1.c", "big2.a", "big2.b", "big2.c", "big3.a", + "big3.b", "big3.c", "small1.d", "small2.d", "small3.d", "small4.d").map(_.attr): _*) + .analyze + + comparePlans(optimized, expected) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala index f362d85e5feda8c32e7075b8742bcc19a45b30a9..cc724b31ad30b1e4f0d152194bd927ed71d7b3f3 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala @@ -18,33 +18,82 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType} class ColumnarExecSuite extends ColumnarSparkPlanTest { - private lazy val df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0, false), - Row(1, 2.0, false), - Row(2, 1.0, false), - Row(null, null, false), - Row(null, 5.0, false), - Row(6, null, false) - )), new StructType().add("a", IntegerType).add("b", DoubleType) - .add("c", BooleanType)) + private var dealer: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + + dealer = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, false), + Row(1, 2.0, false), + Row(2, 1.0, false), + Row(null, null, false), + Row(null, 5.0, false), + Row(6, null, false) + )), new StructType().add("a", IntegerType).add("b", DoubleType) + .add("c", BooleanType)) + dealer.createOrReplaceTempView("dealer") + } test("validate columnar transfer exec happened") { - val res = df.filter("a > 1") - 
print(res.queryExecution.executedPlan) - assert(res.queryExecution.executedPlan.find(_.isInstanceOf[RowToOmniColumnarExec]).isDefined, s"RowToOmniColumnarExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + val sql1 = "SELECT a + 1 FROM dealer" + assertColumnarToRowOmniAndSparkResultEqual(sql1, false) } - test("validate data type convert") { - val res = df.filter("a > 1") - print(res.queryExecution.executedPlan) + test("spark limit with columnarToRow as child") { + + // fetch partial + val sql1 = "select * from (select a, b+2 from dealer order by a, b+2) limit 2" + assertColumnarToRowOmniAndSparkResultEqual(sql1, false) + + // fetch all + val sql2 = "select a, b+2 from dealer limit 6" + assertColumnarToRowOmniAndSparkResultEqual(sql2, true) + + // fetch all + val sql3 = "select a, b+2 from dealer limit 10" + assertColumnarToRowOmniAndSparkResultEqual(sql3, true) + + // fetch partial + val sql4 = "select a, b+2 from dealer order by a limit 2" + assertColumnarToRowOmniAndSparkResultEqual(sql4, false) + + // fetch all + val sql5 = "select a, b+2 from dealer order by a limit 6" + assertColumnarToRowOmniAndSparkResultEqual(sql5, false) + + // fetch all + val sql6 = "select a, b+2 from dealer order by a limit 10" + assertColumnarToRowOmniAndSparkResultEqual(sql6, false) + } + + private def assertColumnarToRowOmniAndSparkResultEqual(sql: String, mayPartialFetch: Boolean = true): Unit = { + + spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", true) + spark.conf.set("spark.omni.sql.columnar.project", true) + val omniResult = spark.sql(sql) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined, + s"SQL:${sql}\n@OmniEnv no OmniColumnarToRowExec,omniPlan:${omniPlan}") + assert(omniPlan.find(_.isInstanceOf[OmniColumnarToRowExec]).get + .asInstanceOf[OmniColumnarToRowExec].mayPartialFetch == mayPartialFetch, + s"SQL:${sql}\n@OmniEnv OmniColumnarToRowExec mayPartialFetch value wrong:${omniPlan}") + + spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", false) + spark.conf.set("spark.omni.sql.columnar.project", false) + val sparkResult = spark.sql(sql) + val sparkPlan = sparkResult.queryExecution.executedPlan + assert(sparkPlan.find(_.isInstanceOf[OmniColumnarToRowExec]).isEmpty, + s"SQL:${sql}\n@SparkEnv have OmniColumnarToRowExec,sparkPlan:${sparkPlan}") - checkAnswer( - df.filter("a > 1"), - Row(2, 1.0, false) :: Row(6, null, false) :: Nil) + assert(omniResult.except(sparkResult).isEmpty, + s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") + spark.conf.set("spark.omni.sql.columnar.takeOrderedAndProject", true) + spark.conf.set("spark.omni.sql.columnar.project", true) } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFileScanExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFileScanExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ec86f55013f013185880893cfff0b81ff6c89270 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarFileScanExecSuite.scala @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, Row, SparkSession} + +import java.io.File + +class ColumnarFileScanExecSuite extends ColumnarSparkPlanTest { + private var load: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + } + + test("validate columnar filescan exec for parquet happened") { + val file = new File("src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/date_dim.parquet") + val path = file.getAbsolutePath + load = spark.read.parquet(path) + load.createOrReplaceTempView("parquet_scan_table") + val res = spark.sql("select * from parquet_scan_table") + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarFileSourceScanExec]).isDefined, + s"ColumnarFileSourceScanExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + res.show() + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala index 11795954d94cd29444878dfbf369255cbcf0a164..57d022c1fdcfedd35c9ef00ff746aee25d16d2ab 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateDistinctOperatorSuite.scala @@ -67,6 +67,11 @@ class ColumnarHashAggregateDistinctOperatorSuite extends ColumnarSparkPlanTest { dealer_decimal.createOrReplaceTempView("dealer_decimal") } + test("check columnar hashAgg filter result with distinct") { + val sql1 = "select id, count(distinct car_model) filter (where quantity is not null) from dealer group by id" + assertHashAggregateExecOmniAndSparkResultEqual(sql1) + } + test("Test HashAgg with 1 distinct:") { val sql1 = "SELECT car_model, count(DISTINCT quantity) AS count FROM dealer" + " GROUP BY car_model;" @@ -164,7 +169,7 @@ class ColumnarHashAggregateDistinctOperatorSuite extends ColumnarSparkPlanTest { test("Test HashAgg with decimal distinct:") { val sql1 = "select car_model, avg(DISTINCT quantity_dec8_2), count(DISTINCT city) from dealer_decimal" + " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql1) + assertHashAggregateExecOmniAndSparkResultEqual(sql1, hashAggExecFullReplace = false) val sql2 = "select car_model, min(id), sum(DISTINCT quantity_dec8_2), count(DISTINCT city) from dealer_decimal" + " group by car_model;" @@ -178,7 +183,7 @@ class ColumnarHashAggregateDistinctOperatorSuite extends ColumnarSparkPlanTest { val sql4 = "select car_model, avg(DISTINCT quantity_dec11_2), count(DISTINCT city) from 
dealer_decimal" + " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql4) + assertHashAggregateExecOmniAndSparkResultEqual(sql4, hashAggExecFullReplace = false) val sql5 = "select car_model, min(id), sum(DISTINCT quantity_dec11_2), count(DISTINCT city) from dealer_decimal" + " group by car_model;" @@ -192,11 +197,11 @@ class ColumnarHashAggregateDistinctOperatorSuite extends ColumnarSparkPlanTest { val sql7 = "select car_model, count(DISTINCT quantity_dec8_2), avg(DISTINCT quantity_dec8_2), sum(DISTINCT quantity_dec8_2) from dealer_decimal" + " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql7) + assertHashAggregateExecOmniAndSparkResultEqual(sql7, hashAggExecFullReplace = false) val sql8 = "select car_model, count(DISTINCT quantity_dec11_2), avg(DISTINCT quantity_dec11_2), sum(DISTINCT quantity_dec11_2) from dealer_decimal" + " group by car_model;" - assertHashAggregateExecOmniAndSparkResultEqual(sql8) + assertHashAggregateExecOmniAndSparkResultEqual(sql8, hashAggExecFullReplace = false) } test("Test HashAgg with multi distinct + multi without distinct + order by:") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala index 5c732d6b97eff3be3e46d6d3afda3c776444e6b0..e69ef0258e33642d0b264d8ac4576aecdbcc6af4 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.functions.{sum, count} +import org.apache.spark.sql.functions.{avg, count, first, max, min, sum} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -36,6 +36,16 @@ class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { Row(null, 5.0, 7L, "f") )), new StructType().add("a", IntegerType).add("b", DoubleType) .add("c", LongType).add("d", StringType)) + df.createOrReplaceTempView("df_tbl") + } + + test("check columnar hashAgg filter result") { + val res = spark.sql("select a, sum(b) filter (where c > 1) from df_tbl group by a") + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + checkAnswer( + res, + Seq(Row(null, 5.0), Row(1, 2.0), Row(2, 1.0)) + ) } test("validate columnar hashAgg exec happened") { @@ -77,4 +87,69 @@ class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { Seq(Row(1, 2), Row(2, 1), Row(null, 2)) ) } + + test("test hashAgg null") { + var res = df.filter(df("a").equalTo(3)).groupBy("a").agg(sum("a")) + checkAnswer( + res, + Seq.empty + ) + res = df.filter(df("a").equalTo(3)).groupBy("a").agg(max("a")) + checkAnswer( + res, + Seq.empty + ) + res = df.filter(df("a").equalTo(3)).groupBy("a").agg(min("a")) + checkAnswer( + res, + Seq.empty + ) + res = df.filter(df("a").equalTo(3)).groupBy("a").agg(avg("a")) + checkAnswer( + res, + Seq.empty + ) + res = df.filter(df("a").equalTo(3)).groupBy("a").agg(first("a")) + checkAnswer( + res, + Seq.empty + ) + res = df.filter(df("a").equalTo(3)).groupBy("a").agg(count("a")) + checkAnswer( + res, + Seq.empty + ) + } + test("test 
agg null") { + var res = df.filter(df("a").equalTo(3)).agg(sum("a")) + checkAnswer( + res, + Seq(Row(null)) + ) + res = df.filter(df("a").equalTo(3)).agg(max("a")) + checkAnswer( + res, + Seq(Row(null)) + ) + res = df.filter(df("a").equalTo(3)).agg(min("a")) + checkAnswer( + res, + Seq(Row(null)) + ) + res = df.filter(df("a").equalTo(3)).agg(avg("a")) + checkAnswer( + res, + Seq(Row(null)) + ) + res = df.filter(df("a").equalTo(3)).agg(first("a")) + checkAnswer( + res, + Seq(Row(null)) + ) + res = df.filter(df("a").equalTo(3)).agg(count("a")) + checkAnswer( + res, + Seq(Row(0)) + ) + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index 4add4dd8077b5b4814cd993aae792677c99e260a..66acf366a70c1ee273f210fe13993ee17cc64845 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -21,9 +21,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildRight -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} // refer to joins package class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { @@ -34,6 +35,8 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { private var right: DataFrame = _ private var leftWithNull: DataFrame = _ private var rightWithNull: DataFrame = _ + private var person_test: DataFrame = _ + private var order_test: DataFrame = _ protected override def beforeAll(): Unit = { super.beforeAll() @@ -64,6 +67,29 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { (" add", null, 1, null), (" yeah ", null, null, 4.0) ).toDF("a", "b", "c", "d") + + person_test = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(3, "Carter"), + Row(1, "Adams"), + Row(2, "Bush") + )), new StructType() + .add("id_p", IntegerType) + .add("name", StringType)) + person_test.createOrReplaceTempView("person_test") + + order_test = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(5, 34764, 65), + Row(1, 77895, 3), + Row(2, 44678, 3), + Row(4, 24562, 1), + Row(3, 22456, 1) + )), new StructType() + .add("id_o", IntegerType) + .add("order_no", IntegerType) + .add("id_p", IntegerType)) + order_test.createOrReplaceTempView("order_test") } test("validate columnar broadcastHashJoin exec happened") { @@ -90,45 +116,113 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { } test("columnar sortMergeJoin Inner Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) + val enableFusionArr = Array(false, true) + for (item <- 
enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) + } } test("columnar sortMergeJoin Inner Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(leftWithNull.col("q").expr) + val rightKeys = Seq(rightWithNull.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, Inner) + } } test("columnar sortMergeJoin LeftOuter Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) + } } test("columnar sortMergeJoin LeftOuter Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(leftWithNull.col("q").expr) + val rightKeys = Seq(rightWithNull.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftOuter) + } } test("columnar sortMergeJoin FullOuter Join is equal to native") { - val df = left.join(right.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(left.col("q").expr) - val rightKeys = Seq(right.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) + } } test("columnar sortMergeJoin FullOuter Join is equal to native With NULL") { - val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) - val leftKeys = Seq(leftWithNull.col("q").expr) - val rightKeys = Seq(rightWithNull.col("c").expr) - checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + 
spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(leftWithNull.col("q").expr) + val rightKeys = Seq(rightWithNull.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, FullOuter) + } + } + + test("columnar sortMergeJoin LeftSemi Join is equal to native") { + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftSemi) + } + } + + test("columnar sortMergeJoin LeftSemi Join is equal to native With NULL") { + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(leftWithNull.col("q").expr) + val rightKeys = Seq(rightWithNull.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftSemi) + } + } + + test("columnar sortMergeJoin LeftAnti Join is equal to native") { + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = left.join(right.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftAnti) + } + } + + test("columnar sortMergeJoin LeftAnti Join is equal to native With NULL") { + val enableFusionArr = Array(false, true) + for (item <- enableFusionArr) { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", item) + val df = leftWithNull.join(rightWithNull.hint("mergejoin"), col("q") === col("c")) + val leftKeys = Seq(leftWithNull.col("q").expr) + val rightKeys = Seq(rightWithNull.col("c").expr) + checkThatPlansAgreeTemplateForSMJ(df, leftKeys, rightKeys, LeftAnti) + } } test("columnar broadcastHashJoin is equal to native with null") { @@ -139,6 +233,21 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys) } + test("columnar broadcastHashJoin LeftSemi Join is equal to native") { + val df = left.join(right.hint("broadcast"), col("q") === col("c")) + val leftKeys = Seq(left.col("q").expr) + val rightKeys = Seq(right.col("c").expr) + checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys, LeftSemi) + } + + test("columnar broadcastHashJoin LeftSemi Join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("broadcast"), + col("q").isNotNull === col("c").isNotNull) + val leftKeys = Seq(leftWithNull.col("q").isNotNull.expr) + val rightKeys = Seq(rightWithNull.col("c").isNotNull.expr) + checkThatPlansAgreeTemplateForBHJ(df, leftKeys, rightKeys, LeftSemi) + } + def checkThatPlansAgreeTemplateForBHJ(df: DataFrame, leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType = Inner): Unit = { checkThatPlansAgree( @@ -224,6 +333,59 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { ), false) } + test("validate columnar shuffledHashJoin left semi join happened") { + val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftsemi") + assert( + 
res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, + s"ColumnarShuffledHashJoinExec not happened," + + s" executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar shuffledHashJoin left semi join is equal to native") { + val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftsemi") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", "", 4, 2.0), + Row("", "Hello", 1, 1.0) + ), false) + } + + test("columnar shuffledHashJoin left semi join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), + col("q") === col("c"), "leftsemi") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", null, 4, 2.0) + ), false) + } + + test("validate columnar shuffledHashJoin left outer join happened") { + val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") + assert( + res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, + s"ColumnarShuffledHashJoinExec not happened," + + s" executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar shuffledHashJoin left outer join is equal to native") { + val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), + Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), + Row(" add", "World", 8, 3.0, null, null, null, null) + ), false) + } + + test("columnar shuffledHashJoin left outer join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), + col("q") === col("c"), "leftouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), + Row("", "Hello", null, 1.0, null, null, null, null), + Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row(" add", "World", 8, 3.0, null, null, null, null) + ), false) + } + test("ColumnarBroadcastHashJoin is not rolled back with not_equal filter expr") { val res = left.join(right.hint("broadcast"), left("a") <=> right("a")) assert( @@ -244,4 +406,250 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { None, child, child), sortAnswers = true) } -} \ No newline at end of file + + test("BroadcastHashJoin and project fusion test") { + val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") + .select(person_test("name"), order_test("order_no")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 44678), + Row("Carter", 77895), + Row("Adams", 22456), + Row("Adams", 24562), + Row("Bush", null) + ), false) + } + + test("BroadcastHashJoin and project fusion test for duplicate column") { + val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") + .select(person_test("name"), order_test("order_no"), order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 
44678, 3), + Row("Carter", 77895, 3), + Row("Adams", 22456, 1), + Row("Adams", 24562, 1), + Row("Bush", null, null) + ), false) + } + + test("BroadcastHashJoin and project fusion test for reorder columns") { + val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") + .select(order_test("order_no"), person_test("name"), order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(44678, "Carter", 3), + Row(77895, "Carter", 3), + Row(22456, "Adams", 1), + Row(24562, "Adams", 1), + Row(null, "Bush", null) + ), false) + } + + test("BroadcastHashJoin and project are not fused test") { + val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") + .select(order_test("order_no").plus(1), person_test("name")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, + s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(44679, "Carter"), + Row(77896, "Carter"), + Row(22457, "Adams"), + Row(24563, "Adams"), + Row(null, "Bush") + ), false) + } + + test("BroadcastHashJoin and project fusion test for alias") { + val omniResult = person_test.join(order_test.hint("broadcast"), person_test("id_p") === order_test("id_p"), "leftouter") + .select(person_test("name").as("name1"), order_test("order_no").as("order_no1")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 44678), + Row("Carter", 77895), + Row("Adams", 22456), + Row("Adams", 24562), + Row("Bush", null) + ), false) + } + + test("validate columnar shuffledHashJoin left anti join happened") { + val res = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftanti") + assert( + res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffledHashJoinExec]).isDefined, + s"ColumnarShuffledHashJoinExec not happened," + + s" executedPlan as follows: \n${res.queryExecution.executedPlan}") + } + + test("columnar shuffledHashJoin left anti join is equal to native") { + val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftanti") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row(" yeah ", "yeah", 10, 8.0), + Row(" add", "World", 8, 3.0) + ), false) + } + + test("columnar shuffledHashJoin left anti join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), + col("q") === col("c"), "leftanti") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("", "Hello", null, 1.0), + Row(" yeah ", "yeah", 10, 8.0), + Row(" add", "World", 8, 3.0) + ), false) + } + + test("shuffledHashJoin and project fusion test") { + val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name"), order_test("order_no")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, 
Seq( + Row("Carter", 44678), + Row("Carter", 77895), + Row("Adams", 22456), + Row("Adams", 24562) + ), false) + } + + test("ShuffledHashJoin and project fusion test for duplicate column") { + val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name"), order_test("order_no"), order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 44678, 3), + Row("Carter", 77895, 3), + Row("Adams", 22456, 1), + Row("Adams", 24562, 1) + ), false) + } + + test("ShuffledHashJoin and project fusion test for reorder columns") { + val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") + .select(order_test("order_no"), person_test("name"), order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(44678, "Carter", 3), + Row(77895, "Carter", 3), + Row(22456, "Adams", 1), + Row(24562, "Adams", 1) + ), false) + } + + test("ShuffledHashJoin and project are not fused test") { + val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") + .select(order_test("order_no").plus(1), person_test("name")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, + s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(44679, "Carter"), + Row(77896, "Carter"), + Row(22457, "Adams"), + Row(24563, "Adams") + ), false) + } + + test("ShuffledHashJoin and project fusion test for alias") { + val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name").as("name1"), order_test("order_no").as("order_no1")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 44678), + Row("Carter", 77895), + Row("Adams", 22456), + Row("Adams", 24562) + ), false) + } + + test("SortMergeJoin and project fusion test") { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", false) + val omniResult = person_test.join(order_test.hint("MERGEJOIN"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name"), order_test("order_no")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 77895), + Row("Carter", 44678), + Row("Adams", 24562), + Row("Adams", 22456) + ), false) + } + + test("SortMergeJoin and project fusion test for duplicate column") { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", false) + val omniResult = person_test.join(order_test.hint("MERGEJOIN"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name"), order_test("order_no"), 
order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 77895, 3), + Row("Carter", 44678, 3), + Row("Adams", 24562, 1), + Row("Adams", 22456, 1) + ), false) + } + + test("SortMergeJoin and project fusion test for reorder columns") { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", false) + val omniResult = person_test.join(order_test.hint("MERGEJOIN"), person_test("id_p") === order_test("id_p"), "inner") + .select(order_test("order_no"), person_test("name"), order_test("id_p")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(77895, "Carter", 3), + Row(44678, "Carter", 3), + Row(24562, "Adams", 1), + Row(22456, "Adams", 1) + ), false) + } + + test("SortMergeJoin and project are not fused test") { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", false) + val omniResult = person_test.join(order_test.hint("MERGEJOIN"), person_test("id_p") === order_test("id_p"), "inner") + .select(order_test("order_no").plus(1), person_test("name")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, + s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row(77896, "Carter"), + Row(44679, "Carter"), + Row(24563, "Adams"), + Row(22457, "Adams") + ), false) + } + + test("SortMergeJoin and project fusion test for alias") { + spark.conf.set("spark.omni.sql.columnar.sortMergeJoin.fusion", false) + val omniResult = person_test.join(order_test.hint("MERGEJOIN"), person_test("id_p") === order_test("id_p"), "inner") + .select(person_test("name").as("name1"), order_test("order_no").as("order_no1")) + val omniPlan = omniResult.queryExecution.executedPlan + assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, + s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") + checkAnswer(omniResult, _ => omniPlan, Seq( + Row("Carter", 77895), + Row("Carter", 44678), + Row("Adams", 24562), + Row("Adams", 22456) + ), false) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..8ff50e2673063c6e96abae5470a7098bc8683a30 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types._ + +class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { + + private var dealer: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + + dealer = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, "shanghai", 10), + Row(2, "chengdu", 1), + Row(3, "guangzhou", 7), + Row(4, "beijing", 20), + Row(5, "hangzhou", 4), + Row(6, "tianjing", 3), + Row(7, "shenzhen", 5), + Row(8, "changsha", 5), + Row(9, "nanjing", 5), + Row(10, "wuhan", 6) + )), new StructType() + .add("id", IntegerType) + .add("city", StringType) + .add("sales", IntegerType)) + dealer.createOrReplaceTempView("dealer") + } + + test("Test topNSort") { + val sql1 = "select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk<4 order by rk;" + assertColumnarTopNSortExecAndSparkResultEqual(sql1, true) + val sql2 = "select * from (SELECT city, row_number() OVER (ORDER BY sales) AS rn FROM dealer) where rn<4 order by rn;" + assertColumnarTopNSortExecAndSparkResultEqual(sql2, false) + val sql3 = "select * from (SELECT city, rank() OVER (PARTITION BY city ORDER BY sales) AS rk FROM dealer) where rk <4 order by rk;" + assertColumnarTopNSortExecAndSparkResultEqual(sql3, true) + } + + private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true): Unit = { + // run with ColumnarTopNSortExec enabled + spark.conf.set("spark.omni.sql.columnar.topnsort", true) + spark.conf.set("spark.omni.sql.columnar.topnsortthreshold", 100) + val omniResult = spark.sql(sql) + val omniPlan = omniResult.queryExecution.executedPlan + if (hasColumnarTopNSortExec) { + assert(omniPlan.find(_.isInstanceOf[ColumnarTopNSortExec]).isDefined, + s"SQL:${sql}\n@OmniEnv no ColumnarTopNSortExec, omniPlan:${omniPlan}") + } + + // run with ColumnarTopNSortExec disabled + spark.conf.set("spark.omni.sql.columnar.topnsort", false) + val sparkResult = spark.sql(sql) + val sparkPlan = sparkResult.queryExecution.executedPlan + assert(sparkPlan.find(_.isInstanceOf[ColumnarTopNSortExec]).isEmpty, + s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") + // DataFrames cannot be compared with equals, so use DataFrame.except instead + // DataFrame.except ignores row order, so results with and without order by compare equal + assert(omniResult.except(sparkResult).isEmpty, + s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") + spark.conf.set("spark.omni.sql.columnar.topnsort", true) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala index 4f11256f47f32ba8088a6bd3f201e53bddcfeb46..0700a83ba9ae788e7e4376cd16bd162552b798e3 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala +++ 
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowExecSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession @@ -46,4 +46,16 @@ class ColumnarWindowExecSuite extends ColumnarSparkPlanTest with SharedSparkSess res2.head(10).foreach(row => println(row)) assert(res2.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n${res2.queryExecution.executedPlan}") } + + test("check columnar window result") { + val res1 = Window.partitionBy("a").orderBy('c.asc) + val res2 = inputDf.withColumn("max", max("c").over(res1)) + assert(res2.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isEmpty, s"ColumnarSortExec happened, executedPlan as follows: \n${res2.queryExecution.executedPlan}") + assert(res2.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined, s"ColumnarWindowExec not happened, executedPlan as follows: \n${res2.queryExecution.executedPlan}") + checkAnswer( + res2, + Seq(Row(" add", "World", 8, 3.0, 8), Row(" yeah ", "yeah", 10, 8.0, 10), Row("abc", "", 4, 2.0, 4), + Row("abc", "", 10, 8.0, 10), Row("", "Hello", 1, 1.0, 1)) + ) + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index cf2537484aefdcd68214db0046877652847bb34b..562d63db4a3796c8526811d1a959e2cd3b368fc9 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.log4j.Level import org.apache.spark.Partition -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} @@ -155,7 +155,12 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest val columnarCus = r.asInstanceOf[ColumnarCustomShuffleReaderExec] val rdd: RDD[ColumnarBatch] = columnarCus.executeColumnar() val parts: Array[Partition] = rdd.partitions - assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + rdd match { + case mapPartitionsRDD: MapPartitionsRDD[ColumnarBatch, ColumnarBatch] => + assert(parts.forall(mapPartitionsRDD.prev.preferredLocations(_).nonEmpty)) + case _ => + assert(parts.forall(rdd.asInstanceOf[ShuffledColumnarRDD].preferredLocations(_).nonEmpty)) + } } assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader)) } @@ -201,20 +206,27 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader } assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val 
localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - // The pre-shuffle partition size is [0, 0, 0, 72, 0] - // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 - // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1 - // and the partitions length is 1 * numMappers = 2 - assert(localShuffleRDD0.getPartitions.length == 2) - // The pre-shuffle partition size is [0, 72, 0, 72, 126] - // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 - // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1 - // and the partitions length is 1 * numMappers = 2 - assert(localShuffleRDD1.getPartitions.length == 2) + + val localRDD0 = localReaders(0).executeColumnar() + val localRDD1 = localReaders(1).executeColumnar() + localRDD0 match { + case mapPartitionsRDD: MapPartitionsRDD[ColumnarBatch, ColumnarBatch] => + // The pre-shuffle partition size is [0, 0, 0, 72, 0] + // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 + // the final parallelism is + // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1 + // and the partitions length is 1 * numMappers = 2 + assert(localRDD0.asInstanceOf[MapPartitionsRDD[ColumnarBatch, ColumnarBatch]].getPartitions.length == 2) + // The pre-shuffle partition size is [0, 72, 0, 72, 126] + // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 + // the final parallelism is + // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1 + // and the partitions length is 1 * numMappers = 2 + assert(localRDD1.asInstanceOf[MapPartitionsRDD[ColumnarBatch, ColumnarBatch]].getPartitions.length == 2) + case _ => + assert(localRDD0.asInstanceOf[ShuffledColumnarRDD].getPartitions.length == 2) + assert(localRDD1.asInstanceOf[ShuffledColumnarRDD].getPartitions.length == 2) + } } } @@ -233,14 +245,21 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader } assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 - // and the partitions length is 2 * numMappers = 4 - assert(localShuffleRDD0.getPartitions.length == 4) - // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 - // and the partitions length is 2 * numMappers = 4 - assert(localShuffleRDD1.getPartitions.length == 4) + + val localRDD0 = localReaders(0).executeColumnar() + val localRDD1 = localReaders(1).executeColumnar() + localRDD0 match { + case mapPartitionsRDD: MapPartitionsRDD[ColumnarBatch, ColumnarBatch] => + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 + // and the partitions length is 2 * numMappers = 4 + assert(localRDD0.asInstanceOf[MapPartitionsRDD[ColumnarBatch, ColumnarBatch]].getPartitions.length == 4) + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 + // and the partitions length is 2 * numMappers = 4 + assert(localRDD1.asInstanceOf[MapPartitionsRDD[ColumnarBatch, ColumnarBatch]].getPartitions.length == 4) + case _ => + assert(localRDD0.asInstanceOf[ShuffledColumnarRDD].getPartitions.length == 4) + 
assert(localRDD1.asInstanceOf[ShuffledColumnarRDD].getPartitions.length == 4) + } } } @@ -255,7 +274,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest .groupBy('a).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan - assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) + assert(find(plan)(_.isInstanceOf[ColumnarSortMergeJoinExec]).isDefined) val coalescedReaders = collect(plan) { case r: ColumnarCustomShuffleReaderExec => r } @@ -599,7 +618,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest "join testData2 t3 on t2.a = t3.a where t2.b = 1") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 2) - val smj2 = findTopLevelSortMergeJoin(adaptivePlan) + val smj2 = findTopLevelColumnarSortMergeJoin(adaptivePlan) assert(smj2.size == 2, origPlan.toString) } } @@ -727,7 +746,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest .createOrReplaceTempView("skewData2") def checkSkewJoin( - joins: Seq[SortMergeJoinExec], + joins: Seq[ColumnarSortMergeJoinExec], leftSkewNum: Int, rightSkewNum: Int): Unit = { assert(joins.size == 1 && joins.head.isSkewJoin) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala index ce3e7ab8576a47a40e7a434847547b0978824043..20879ad520d2a4fc0f6af0f20e3720d64105e03e 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarBuiltInFuncSuite.scala @@ -26,436 +26,612 @@ class ColumnarBuiltInFuncSuite extends ColumnarSparkPlanTest{ private var buildInDf: DataFrame = _ + private var buildInDfNum: DataFrame = _ + protected override def beforeAll(): Unit = { super.beforeAll() buildInDf = Seq[(String, String, String, String, Long, Int, String, String)]( - (null, "ChaR1 R", null, " varchar100 ", 1001L, 1, "中文1", "varchar100_normal"), - ("char200 ", "char2 ", "varchar2", "", 1002L, 2, "中文2", "varchar200_normal"), - ("char300 ", "char3 ", "varchar3", "varchar300", 1003L, 3, "中文3", "varchar300_normal"), - (null, "char4 ", "varchar4", "varchar400", 1004L, 4, "中文4", "varchar400_normal") + (null, "ChaR1 R", null, " varchar100 ", 1001L, 1, " 中文1aA ", "varchar100_normal"), + ("char200 ", "char2 ", "varchar2", "", 1002L, 2, "中文2bB", "varchar200_normal"), + ("char300 ", "char3 ", "varchar3", "varchar300", 1003L, 3, "中文3cC", "varchar300_normal"), + (null, "char4 ", "varchar4", "varchar400", 1004L, 4, null, "varchar400_normal") ).toDF("char_null", "char_normal", "varchar_null", "varchar_empty", "long_col", "int_col", "ch_col", "varchar_normal") buildInDf.createOrReplaceTempView("builtin_table") + + buildInDfNum = Seq[(Double, Int, Double, Int)]( + (123.12345, 1, -123.12345, 134), + (123.1257, 2, -123.1257, 1267), + (123.12, 3, -123.12, 1650), + (123.1, 4, -123.1, 166667) + ).toDF("double1", "int2", "double3", "int4") + buildInDfNum.createOrReplaceTempView("test_table") } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower with normal") { - val res = spark.sql("select lower(char_normal) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec 
not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row("char1 r"), - Row("char2 "), - Row("char3 "), - Row("char4 ") - ) + val sql = "select lower(char_normal) from builtin_table" + val expected = Seq( + Row("char1 r"), + Row("char2 "), + Row("char3 "), + Row("char4 ") ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower with null") { - val res = spark.sql("select lower(char_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null), - Row("char200 "), - Row("char300 "), - Row(null) - ) + val sql = "select lower(char_null) from builtin_table" + val expected = Seq( + Row(null), + Row("char200 "), + Row("char300 "), + Row(null) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower with space/empty string") { - val res = spark.sql("select lower(varchar_empty) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(" varchar100 "), - Row(""), - Row("varchar300"), - Row("varchar400") - ) + val sql = "select lower(varchar_empty) from builtin_table" + val expected = Seq( + Row(" varchar100 "), + Row(""), + Row("varchar300"), + Row("varchar400") ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower-lower") { - val res = spark.sql("select lower(char_null), lower(varchar_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null, null), - Row("char200 ", "varchar2"), - Row("char300 ", "varchar3"), - Row(null, "varchar4"), - ) + val sql = "select lower(char_null), lower(varchar_null) from builtin_table" + val expected = Seq( + Row(null, null), + Row("char200 ", "varchar2"), + Row("char300 ", "varchar3"), + Row(null, "varchar4"), ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower(lower())") { - val res = spark.sql("select lower(lower(char_null)) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null), - 
Row("char200 "), - Row("char300 "), - Row(null) - ) + val sql = "select lower(lower(char_null)) from builtin_table" + val expected = Seq( + Row(null), + Row("char200 "), + Row("char300 "), + Row(null) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower with subQuery") { - val res = spark.sql("select lower(l) from (select lower(char_normal) as l from builtin_table)") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row("char1 r"), - Row("char2 "), - Row("char3 "), - Row("char4 ") - ) + val sql = "select lower(l) from (select lower(char_normal) as l from builtin_table)" + val expected = Seq( + Row("char1 r"), + Row("char2 "), + Row("char3 "), + Row("char4 ") ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute lower with ch") { - val res = spark.sql("select lower(ch_col) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row("中文1"), - Row("中文2"), - Row("中文3"), - Row("中文4") - ) + val sql = "select lower(ch_col) from builtin_table" + val expected = Seq( + Row(" 中文1aa "), + Row("中文2bb"), + Row("中文3cc"), + Row(null) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute length with normal") { - val res = spark.sql("select length(char_normal) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(10), - Row(10), - Row(10), - Row(10) - ) + val sql = "select length(char_normal) from builtin_table" + val expected = Seq( + Row(10), + Row(10), + Row(10), + Row(10) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute length with null") { - val res = spark.sql("select length(char_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null), - Row(10), - Row(10), - Row(null) - ) + val sql = "select length(char_null) from builtin_table" + val expected = Seq( + Row(null), + Row(10), + Row(10), + Row(null) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute length with space/empty string") { - val res = spark.sql("select length(varchar_empty) from 
builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(13), - Row(0), - Row(10), - Row(10) - ) + val sql = "select length(varchar_empty) from builtin_table" + val expected = Seq( + Row(13), + Row(0), + Row(10), + Row(10) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute length with expr") { - val res = spark.sql("select length(char_null) / 2 from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null), - Row(5.0), - Row(5.0), - Row(null) - ) + val sql = "select length(char_null) / 2 from builtin_table" + val expected = Seq( + Row(null), + Row(5.0), + Row(5.0), + Row(null) ) + checkResult(sql, expected) } test("Test ColumnarProjectExec happen and result is same as native " + "when execute length-length") { - val res = spark.sql("select length(char_null),length(varchar_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan - assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") - assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") - checkAnswer( - res, - Seq( - Row(null, null), - Row(10, 8), - Row(10, 8), - Row(null, 8) - ) + val sql = "select length(char_null),length(varchar_null) from builtin_table" + val expected = Seq( + Row(null, null), + Row(10, 8), + Row(10, 8), + Row(null, 8) ) + checkResult(sql, expected) } // replace(str, search, replaceStr) test("Test ColumnarProjectExec happen and result is same as native " + "when execute replace with matched and replace str") { - val res = spark.sql("select replace(varchar_normal,varchar_empty,char_normal) from builtin_table") + val sql = "select replace(varchar_normal,varchar_empty,char_normal) from builtin_table" + val expected = Seq( + Row("varchar100_normal"), + Row("varchar200_normal"), + Row("char3 _normal"), + Row("char4 _normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with not matched") { + val sql = "select replace(char_normal,varchar_normal,char_normal) from builtin_table" + val expected = Seq( + Row("ChaR1 R"), + Row("char2 "), + Row("char3 "), + Row("char4 ") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with str null") { + val sql = "select replace(varchar_null,char_normal,varchar_normal) from builtin_table" + val expected = Seq( + Row(null), + Row("varchar2"), + Row("varchar3"), + Row("varchar4") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with str space/empty") { + val sql = "select 
replace(varchar_empty,varchar_empty,varchar_normal) from builtin_table" + val expected = Seq( + Row("varchar100_normal"), + Row(""), + Row("varchar300_normal"), + Row("varchar400_normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with search null") { + val sql = "select replace(varchar_normal,varchar_null,char_normal) from builtin_table" + val expected = Seq( + Row(null), + Row("char2 00_normal"), + Row("char3 00_normal"), + Row("char4 00_normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with search space/empty") { + val sql = "select replace(varchar_normal,varchar_empty,char_normal) from builtin_table" + val expected = Seq( + Row("varchar100_normal"), + Row("varchar200_normal"), + Row("char3 _normal"), + Row("char4 _normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with replaceStr null") { + val sql = "select replace(varchar_normal,varchar_empty,varchar_null) from builtin_table" + val expected = Seq( + Row(null), + Row("varchar200_normal"), + Row("varchar3_normal"), + Row("varchar4_normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with replaceStr space/empty") { + val sql = "select replace(varchar_normal,varchar_normal,varchar_empty) from builtin_table" + val expected = Seq( + Row(" varchar100 "), + Row(""), + Row("varchar300"), + Row("varchar400") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with str/search/replace all null") { + val sql = "select replace(varchar_null,varchar_null,char_null) from builtin_table" + val expected = Seq( + Row(null), + Row("char200 "), + Row("char300 "), + Row(null) + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with replaceStr default") { + val sql = "select replace(varchar_normal,varchar_normal) from builtin_table" + val expected = Seq( + Row(""), + Row(""), + Row(""), + Row("") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with subReplace(normal,normal,normal)") { + val sql = "select replace(res,'c','ccc') from (select replace(varchar_normal,varchar_empty,char_normal) as res from builtin_table)" + val expected = Seq( + Row("varccchar100_normal"), + Row("varccchar200_normal"), + Row("ccchar3 _normal"), + Row("ccchar4 _normal") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace with subReplace(null,null,null)") { + val sql = "select replace(res,'c','ccc') from (select replace(varchar_null,varchar_null,char_null) as res from builtin_table)" + val expected = Seq( + Row(null), + Row("ccchar200 "), + Row("ccchar300 "), + Row(null) + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute replace(replace)") { + val sql = "select replace(replace('ABCabc','AB','abc'),'abc','DEF')" + val expected = Seq( + Row("DEFCDEF") + ) + checkResult(sql, expected) + } + + // upper + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper with normal") { 
+ val sql = "select upper(char_normal) from builtin_table" + val expected = Seq( + Row("CHAR1 R"), + Row("CHAR2 "), + Row("CHAR3 "), + Row("CHAR4 ") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper with null") { + val sql = "select upper(char_null) from builtin_table" + val expected = Seq( + Row(null), + Row("CHAR200 "), + Row("CHAR300 "), + Row(null) + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper with space/empty string") { + val sql = "select upper(varchar_empty) from builtin_table" + val expected = Seq( + Row(" VARCHAR100 "), + Row(""), + Row("VARCHAR300"), + Row("VARCHAR400") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper-upper") { + val sql = "select upper(char_null), upper(varchar_null) from builtin_table" + val expected = Seq( + Row(null, null), + Row("CHAR200 ", "VARCHAR2"), + Row("CHAR300 ", "VARCHAR3"), + Row(null, "VARCHAR4"), + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper(upper())") { + val sql = "select upper(upper(char_null)) from builtin_table" + val expected = Seq( + Row(null), + Row("CHAR200 "), + Row("CHAR300 "), + Row(null) + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper with subQuery") { + val sql = "select upper(l) from (select upper(char_normal) as l from builtin_table)" + val expected = Seq( + Row("CHAR1 R"), + Row("CHAR2 "), + Row("CHAR3 "), + Row("CHAR4 ") + ) + checkResult(sql, expected) + } + + test("Test ColumnarProjectExec happen and result is same as native " + + "when execute upper with ch") { + val sql = "select upper(ch_col) from builtin_table" + val expected = Seq( + Row(" 中文1AA "), + Row("中文2BB"), + Row("中文3CC"), + Row(null) + ) + checkResult(sql, expected) + } + + def checkResult(sql: String, expected: Seq[Row], isUseOmni: Boolean = true): Unit = { + def assertOmniProjectHappen(res: DataFrame): Unit = { + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") + } + def assertOmniProjectNotHappen(res: DataFrame): Unit = { + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"ColumnarProjectExec happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isDefined, s"ProjectExec not happened, executedPlan as follows: \n$executedPlan") + } + val res = spark.sql(sql) + if (isUseOmni) assertOmniProjectHappen(res) else assertOmniProjectNotHappen(res) + checkAnswer(res, expected) + } + + test("Round(int,2)") { + val res = spark.sql("select round(int2,2) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - 
Row("varchar100_normal"), - Row("varchar200_normal"), - Row("char3 _normal"), - Row("char4 _normal") + Row(1), + Row(2), + Row(3), + Row(4) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with not matched") { - val res = spark.sql("select replace(char_normal,varchar_normal,char_normal) from builtin_table") + test("Round(double,2)") { + val res = spark.sql("select round(double1,2) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row("ChaR1 R"), - Row("char2 "), - Row("char3 "), - Row("char4 ") + Row(123.12), + Row(123.13), + Row(123.12), + Row(123.1) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str null") { - val res = spark.sql("select replace(varchar_null,char_normal,varchar_normal) from builtin_table") + test("Round(int,-1)") { + val res = spark.sql("select round(int2,-1) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(null), - Row("varchar2"), - Row("varchar3"), - Row("varchar4") + Row(0), + Row(0), + Row(0), + Row(0) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str space/empty") { - val res = spark.sql("select replace(varchar_empty,varchar_empty,varchar_normal) from builtin_table") + test("Round(double,0)") { + val res = spark.sql("select round(double1,0) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row("varchar100_normal"), - Row(""), - Row("varchar300_normal"), - Row("varchar400_normal") + Row(123), + Row(123), + Row(123), + Row(123) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with search null") { - val res = spark.sql("select replace(varchar_normal,varchar_null,char_normal) from builtin_table") + test("Round(-double,2)") { + val res = spark.sql("select round(double3,2) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(null), - Row("char2 00_normal"), - Row("char3 00_normal"), - Row("char4 00_normal") + Row(-123.12), + Row(-123.13), + Row(-123.12), + Row(-123.1) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with search space/empty") { - val res = spark.sql("select 
replace(varchar_normal,varchar_empty,char_normal) from builtin_table") + test("Round(int,-2)") { + val res = spark.sql("select round(int4,-2) as res from test_table") val executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row("varchar100_normal"), - Row("varchar200_normal"), - Row("char3 _normal"), - Row("char4 _normal") + Row(100), + Row(1300), + Row(1700), + Row(166700) ) ) } - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr null") { - val res = spark.sql("select replace(varchar_normal,varchar_empty,varchar_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan + test("Round decimal") { + var res = spark.sql("select round(2.5, 0) as res from test_table") + var executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(null), - Row("varchar200_normal"), - Row("varchar3_normal"), - Row("varchar4_normal") + Row(3), + Row(3), + Row(3), + Row(3) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr space/empty") { - val res = spark.sql("select replace(varchar_normal,varchar_normal,varchar_empty) from builtin_table") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(3.5, 0) as res from test_table") + executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(" varchar100 "), - Row(""), - Row("varchar300"), - Row("varchar400") + Row(4), + Row(4), + Row(4), + Row(4) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with str/search/replace all null") { - val res = spark.sql("select replace(varchar_null,varchar_null,char_null) from builtin_table") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(-2.5, 0) as res from test_table") + executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(null), - Row("char200 "), - Row("char300 "), - Row(null) + Row(-3), + Row(-3), + Row(-3), + Row(-3) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with replaceStr default") { - val res = spark.sql("select replace(varchar_normal,varchar_normal) from builtin_table") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(-3.5, 0) as res from test_table") + executedPlan = 
res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(""), - Row(""), - Row(""), - Row("") + Row(-4), + Row(-4), + Row(-4), + Row(-4) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with subReplace(normal,normal,normal)") { - val res = spark.sql("select replace(res,'c','ccc') from (select replace(varchar_normal,varchar_empty,char_normal) as res from builtin_table)") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(-0.35, 1) as res from test_table") + executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row("varccchar100_normal"), - Row("varccchar200_normal"), - Row("ccchar3 _normal"), - Row("ccchar4 _normal") + Row(-0.4), + Row(-0.4), + Row(-0.4), + Row(-0.4) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace with subReplace(null,null,null)") { - val res = spark.sql("select replace(res,'c','ccc') from (select replace(varchar_null,varchar_null,char_null) as res from builtin_table)") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(-35, -1) as res from test_table") + executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row(null), - Row("ccchar200 "), - Row("ccchar300 "), - Row(null) + Row(-40), + Row(-40), + Row(-40), + Row(-40) ) ) - } - - test("Test ColumnarProjectExec happen and result is same as native " + - "when execute replace(replace)") { - val res = spark.sql("select replace(replace('ABCabc','AB','abc'),'abc','DEF')") - val executedPlan = res.queryExecution.executedPlan + res = spark.sql("select round(null, 0) as res from test_table") + executedPlan = res.queryExecution.executedPlan assert(executedPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"ColumnarProjectExec not happened, executedPlan as follows: \n$executedPlan") assert(executedPlan.find(_.isInstanceOf[ProjectExec]).isEmpty, s"ProjectExec happened, executedPlan as follows: \n$executedPlan") checkAnswer( res, Seq( - Row("DEFCDEF") + Row(null), + Row(null), + Row(null), + Row(null) ) ) } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala index 1dcdada821f51e0fbfd799516aa6a7e7e1d8e449..2d56cac9dc777417bf9d44995a6b1cb089c67cfd 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala +++ 
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDecimalCastSuite.scala @@ -426,7 +426,7 @@ class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ "when cast double to decimal") { val res = spark.sql("select c_double_normal, cast(c_double_normal as decimal(8, 4))," + "cast(c_double_normal as decimal(32,4)) from deci_double") - assertOmniProjectHappened(res) + assertOmniProjectNotHappened(res) checkAnswer( res, Seq( @@ -441,7 +441,7 @@ class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ "when cast double to decimal overflow with spark.sql.ansi.enabled=false") { val res = spark.sql("select c_double_normal, cast(c_double_normal as decimal(8, 6))," + "cast(c_double_normal as decimal(32,30)) from deci_double") - assertOmniProjectHappened(res) + assertOmniProjectNotHappened(res) checkAnswer( res, Seq( @@ -456,7 +456,7 @@ class ColumnarDecimalCastSuite extends ColumnarSparkPlanTest{ "when cast double to decimal with null") { val res = spark.sql("select c_double_null, cast(c_double_null as decimal(8, 4))," + "cast(c_double_null as decimal(34,4)) from deci_double") - assertOmniProjectHappened(res) + assertOmniProjectNotHappened(res) checkAnswer( res, Seq( diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 026fc59977b443256c933202f1ebb1dbc19ce3d7..df429265fa2abe86ed21db70b35af731605a6a81 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,7 +8,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.1.1-1.1.0 + 3.1.1-1.3.0 BoostKit Spark Native Sql Engine Extension Parent Pom @@ -16,12 +16,12 @@ 2.12.10 2.12 3.1.1 - 3.2.2 + 3.2.0 UTF-8 UTF-8 - 3.15.8 + 3.13.0-h19 FALSE - 1.1.0 + 1.3.0 java