diff --git a/omnioperator/omniop-openlookeng-extension/pom.xml b/omnioperator/omniop-openlookeng-extension/pom.xml index d10d2e43e2ed3c74589b1e2a9bc8ccf9b1923ba4..590586a18cf51d3f2f4276e43874e53c07681626 100644 --- a/omnioperator/omniop-openlookeng-extension/pom.xml +++ b/omnioperator/omniop-openlookeng-extension/pom.xml @@ -22,6 +22,7 @@ 3.1.2-1 2.10.0 1.0.0 + src/test/java @@ -194,6 +195,7 @@ boostkit-omniop-openlookeng-${openLooKeng.version}-${omniruntime.version}-aarch64 + ${test.source.dir} org.jacoco @@ -288,5 +290,13 @@ + + + omni-test + + src/test/omni + + + \ No newline at end of file diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java index 07d16bb753b8865be3ba6fb2f9c17b8479c0ddf1..60c60e9b6e2a93b071e2bd24e0426e82d683049b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java @@ -80,9 +80,9 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.prestosql.spi.type.Decimals.MAX_SHORT_PRECISION; +import static io.prestosql.spi.type.DoubleType.DOUBLE; import static java.lang.Double.doubleToLongBits; import static java.lang.Double.longBitsToDouble; -import static javassist.bytecode.StackMap.DOUBLE; /** * The type Operator utils. 
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/BlockUtil.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/BlockUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..2cdb96b96886bfc4e81607b55dc24b3413a7dc94 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/BlockUtil.java @@ -0,0 +1,469 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.type.CharType; +import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.VarcharType; + +import java.math.BigInteger; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.Decimals.encodeUnscaledValue; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static java.lang.Float.floatToRawIntBits; + +public final class BlockUtil +{ + private BlockUtil() + { + } + + public static Block createStringSequenceBlock(int start, int end, VarcharType type) + { + BlockBuilder builder = type.createBlockBuilder(null, 100); + + for (int i = start; i < end; i++) { + type.writeString(builder, String.valueOf(i)); + } + + return builder.build(); + } + + public static Block createStringSequenceBlock(int start, int end, CharType type) + { + BlockBuilder builder = type.createBlockBuilder(null, 100); + + for (int i = start; i < end; i++) { + type.writeString(builder, String.valueOf(i)); + } + + return builder.build(); + } + + public static Block createIntegerSequenceBlock(int start, int end) + { + BlockBuilder builder = INTEGER.createFixedSizeBlockBuilder(end - start); + + for (int i = start; i < end; i++) { + INTEGER.writeLong(builder, i); + } + + return builder.build(); + } + + public static Block createStringDictionaryBlock(int start, int length, VarcharType type) + { + 
checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = type.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + type.writeString(builder, String.valueOf(i)); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createStringDictionaryBlock(int start, int length, CharType type) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = type.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + type.writeString(builder, String.valueOf(i)); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createLongDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = BIGINT.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + BIGINT.writeLong(builder, i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createIntegerDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = INTEGER.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + INTEGER.writeLong(builder, i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); 
+ } + + public static Block createRealDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = REAL.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + REAL.writeLong(builder, floatToRawIntBits((float) i)); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createDoubleDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = DOUBLE.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + DOUBLE.writeDouble(builder, i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createBooleanDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + BOOLEAN.writeBoolean(builder, i % 2 == 0); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createDateDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = DATE.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + DATE.writeLong(builder, i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % 
dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createTimestampDictionaryBlock(int start, int length) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BlockBuilder builder = TIMESTAMP.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + TIMESTAMP.writeLong(builder, i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createShortDecimalDictionaryBlock(int start, int length, DecimalType type) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + long base = BigInteger.TEN.pow(type.getScale()).longValue(); + + BlockBuilder builder = type.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + type.writeLong(builder, base * i); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createLongDecimalDictionaryBlock(int start, int length, DecimalType type) + { + checkArgument(length > 5, "block must have more than 5 entries"); + + int dictionarySize = length / 5; + BigInteger base = BigInteger.TEN.pow(type.getScale()); + + BlockBuilder builder = type.createBlockBuilder(null, dictionarySize); + for (int i = start; i < start + dictionarySize; i++) { + type.writeSlice(builder, encodeUnscaledValue(BigInteger.valueOf(i).multiply(base))); + } + int[] ids = new int[length]; + for (int i = 0; i < length; i++) { + ids[i] = i % dictionarySize; + } + return new DictionaryBlock(builder.build(), ids); + } + + public static Block createIntegerBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = 
INTEGER.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + INTEGER.writeLong(builder, values.get(i)); + } + + return builder.build(); + } + + public static Block createLongBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + BIGINT.writeLong(builder, values.get(i)); + } + + return builder.build(); + } + + public static Block createRealBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = REAL.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + REAL.writeLong(builder, floatToRawIntBits((float) values.get(i))); + } + + return builder.build(); + } + + public static Block createDoubleBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = DOUBLE.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + DOUBLE.writeDouble(builder, (double) values.get(i)); + } + + return builder.build(); + } + + public static Block createStringBlock(String prefix, List values, VarcharType type) + { + int positionCount = values.size(); + BlockBuilder builder = type.createBlockBuilder(null, positionCount); + for (int i = 0; i < positionCount; i++) { + type.writeString(builder, prefix + values.get(i)); + } + + return builder.build(); + } + + public static Block createBooleanBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = BOOLEAN.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + BOOLEAN.writeBoolean(builder, values.get(i) == 0); + } + + return builder.build(); + } + + public static Block createDateBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = DATE.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + DATE.writeLong(builder, values.get(i)); + } + + 
return builder.build(); + } + + public static Block createTimestampBlock(List values) + { + int positionCount = values.size(); + BlockBuilder builder = TIMESTAMP.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + TIMESTAMP.writeLong(builder, values.get(i)); + } + + return builder.build(); + } + + public static Block createShortDecimalBlock(List values, DecimalType type) + { + int positionCount = values.size(); + long base = BigInteger.TEN.pow(type.getScale()).longValue(); + BlockBuilder builder = type.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + type.writeLong(builder, base * values.get(i)); + } + + return builder.build(); + } + + public static Block createLongDecimalBlock(List values, DecimalType type) + { + int positionCount = values.size(); + BigInteger base = BigInteger.TEN.pow(type.getScale()); + BlockBuilder builder = type.createFixedSizeBlockBuilder(positionCount); + for (int i = 0; i < positionCount; i++) { + type.writeSlice(builder, encodeUnscaledValue(BigInteger.valueOf(values.get(i)).multiply(base))); + } + + return builder.build(); + } + + private static Block createDictionaryBlock(Block block) + { + int dictionarySize = block.getPositionCount(); + int[] ids = new int[dictionarySize]; + for (int i = 0; i < dictionarySize; i++) { + ids[i] = i; + } + return new DictionaryBlock(block, ids); + } + + public static Block createIntegerDictionaryBlock(List values) + { + Block block = createIntegerBlock(values); + return createDictionaryBlock(block); + } + + public static Block createLongDictionaryBlock(List values) + { + Block block = createLongBlock(values); + return createDictionaryBlock(block); + } + + public static Block createRealDictionaryBlock(List values) + { + Block block = createRealBlock(values); + return createDictionaryBlock(block); + } + + public static Block createDoubleDictionaryBlock(List values) + { + Block block = createDoubleBlock(values); + return 
createDictionaryBlock(block); + } + + public static Block createStringDictionaryBlock(String prefix, List values, VarcharType type) + { + Block block = createStringBlock(prefix, values, type); + return createDictionaryBlock(block); + } + + public static Block createBooleanDictionaryBlock(List values) + { + Block block = createBooleanBlock(values); + return createDictionaryBlock(block); + } + + public static Block createDateDictionaryBlock(List values) + { + Block block = createDateBlock(values); + return createDictionaryBlock(block); + } + + public static Block createTimestampDictionaryBlock(List values) + { + Block block = createTimestampBlock(values); + return createDictionaryBlock(block); + } + + public static Block createShortDecimalDictionaryBlock(List values, DecimalType type) + { + Block block = createShortDecimalBlock(values, type); + return createDictionaryBlock(block); + } + + public static Block createLongDecimalDictionaryBlock(List values, DecimalType type) + { + Block block = createLongDecimalBlock(values, type); + return createDictionaryBlock(block); + } + + public static Block buildVarcharBlock(int rowSize, int width, int offset) + { + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, rowSize); + for (int i = 0; i < rowSize; i++) { + VARCHAR.writeString(blockBuilder, createFixedWidthString(i, offset, width)); + } + return blockBuilder.build(); + } + + public static Slice[] getBlockSlices(Block block, int rowSize, int width) + { + Slice[] slice = new Slice[rowSize]; + for (int i = 0; i < rowSize; i++) { + slice[i] = block.getSlice(i, 0, width); + } + return slice; + } + + private static String createFixedWidthString(int index, int offset, int width) + { + String str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + StringBuilder stringBuilder = new StringBuilder(); + for (int j = 0; j < width; j++) { + stringBuilder.append(str.charAt((index + offset + j) % str.length())); + } + return stringBuilder.toString(); + } +} 
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/PageBuilderUtil.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/PageBuilderUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..ef799b5ab26b4b2929ca97fcbf2fbdf550a795fb --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/PageBuilderUtil.java @@ -0,0 +1,242 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk; + +import io.prestosql.block.BlockAssertions; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.type.CharType; +import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarcharType; + +import java.util.List; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.Decimals.isLongDecimal; +import static io.prestosql.spi.type.Decimals.isShortDecimal; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; + +public final class PageBuilderUtil +{ + private PageBuilderUtil() + { + } + + public static Page createSequencePage(List types, int length) + { + return createSequencePage(types, length, new int[types.size()]); + } + + public static Page createSequencePage(List types, int length, int... 
initialValues) + { + Block[] blocks = new Block[initialValues.length]; + for (int i = 0; i < blocks.length; i++) { + Type type = types.get(i); + int initialValue = initialValues[i]; + + if (type.equals(INTEGER)) { + blocks[i] = BlockUtil.createIntegerSequenceBlock(initialValue, initialValue + length); + } + else if (type.equals(BIGINT)) { + blocks[i] = BlockAssertions.createLongSequenceBlock(initialValue, initialValue + length); + } + else if (type.equals(REAL)) { + blocks[i] = BlockAssertions.createSequenceBlockOfReal(initialValue, initialValue + length); + } + else if (type.equals(DOUBLE)) { + blocks[i] = BlockAssertions.createDoubleSequenceBlock(initialValue, initialValue + length); + } + else if (type instanceof VarcharType) { + blocks[i] = BlockUtil.createStringSequenceBlock(initialValue, initialValue + length, + (VarcharType) type); + } + else if (type instanceof CharType) { + blocks[i] = BlockUtil.createStringSequenceBlock(initialValue, initialValue + length, + (CharType) type); + } + else if (type.equals(BOOLEAN)) { + blocks[i] = BlockAssertions.createBooleanSequenceBlock(initialValue, initialValue + length); + } + else if (type.equals(DATE)) { + blocks[i] = BlockAssertions.createDateSequenceBlock(initialValue, initialValue + length); + } + else if (type.equals(TIMESTAMP)) { + blocks[i] = BlockAssertions.createTimestampSequenceBlock(initialValue, initialValue + length); + } + else if (isShortDecimal(type)) { + blocks[i] = BlockAssertions.createShortDecimalSequenceBlock(initialValue, initialValue + length, + (DecimalType) type); + } + else if (isLongDecimal(type)) { + blocks[i] = BlockAssertions.createLongDecimalSequenceBlock(initialValue, initialValue + length, + (DecimalType) type); + } + else { + throw new IllegalStateException("Unsupported type " + type); + } + } + + return new Page(blocks); + } + + public static Page createSequencePageWithDictionaryBlocks(List types, int length) + { + return createSequencePageWithDictionaryBlocks(types, length, new 
int[types.size()]); + } + + public static Page createSequencePageWithDictionaryBlocks(List types, int length, int... initialValues) + { + Block[] blocks = new Block[initialValues.length]; + for (int i = 0; i < blocks.length; i++) { + Type type = types.get(i); + int initialValue = initialValues[i]; + if (type.equals(INTEGER)) { + blocks[i] = BlockUtil.createIntegerDictionaryBlock(initialValue, initialValue + length); + } + else if (type.equals(BIGINT)) { + blocks[i] = BlockUtil.createLongDictionaryBlock(initialValue, initialValue + length); + } + else if (type.equals(REAL)) { + blocks[i] = BlockUtil.createRealDictionaryBlock(initialValue, initialValue + length); + } + else if (type.equals(DOUBLE)) { + blocks[i] = BlockUtil.createDoubleDictionaryBlock(initialValue, initialValue + length); + } + else if (type instanceof VarcharType) { + blocks[i] = BlockUtil.createStringDictionaryBlock(initialValue, initialValue + length, + (VarcharType) type); + } + else if (type instanceof CharType) { + blocks[i] = BlockUtil.createStringDictionaryBlock(initialValue, initialValue + length, + (CharType) type); + } + else if (type.equals(BOOLEAN)) { + blocks[i] = BlockUtil.createBooleanDictionaryBlock(initialValue, initialValue + length); + } + else if (type.equals(DATE)) { + blocks[i] = BlockUtil.createDateDictionaryBlock(initialValue, initialValue + length); + } + else if (type.equals(TIMESTAMP)) { + blocks[i] = BlockUtil.createTimestampDictionaryBlock(initialValue, initialValue + length); + } + else if (isShortDecimal(type)) { + blocks[i] = BlockUtil.createShortDecimalDictionaryBlock(initialValue, initialValue + length, + (DecimalType) type); + } + else if (isLongDecimal(type)) { + blocks[i] = BlockUtil.createLongDecimalDictionaryBlock(initialValue, initialValue + length, + (DecimalType) type); + } + else { + throw new IllegalStateException("Unsupported type " + type); + } + } + + return new Page(blocks); + } + + public static Page createPage(List types, String prefix, List> 
columnValues) + { + Block[] blocks = new Block[types.size()]; + for (int i = 0; i < blocks.length; i++) { + Type type = types.get(i); + if (type.equals(INTEGER)) { + blocks[i] = BlockUtil.createIntegerBlock(columnValues.get(i)); + } + else if (type.equals(BIGINT)) { + blocks[i] = BlockUtil.createLongBlock(columnValues.get(i)); + } + else if (type.equals(REAL)) { + blocks[i] = BlockUtil.createRealBlock(columnValues.get(i)); + } + else if (type.equals(DOUBLE)) { + blocks[i] = BlockUtil.createDoubleBlock(columnValues.get(i)); + } + else if (type instanceof VarcharType) { + blocks[i] = BlockUtil.createStringBlock(prefix, columnValues.get(i), (VarcharType) type); + } + else if (type.equals(BOOLEAN)) { + blocks[i] = BlockUtil.createBooleanBlock(columnValues.get(i)); + } + else if (type.equals(DATE)) { + blocks[i] = BlockUtil.createDateBlock(columnValues.get(i)); + } + else if (type.equals(TIMESTAMP)) { + blocks[i] = BlockUtil.createTimestampBlock(columnValues.get(i)); + } + else if (isShortDecimal(type)) { + blocks[i] = BlockUtil.createShortDecimalBlock(columnValues.get(i), (DecimalType) type); + } + else if (isLongDecimal(type)) { + blocks[i] = BlockUtil.createLongDecimalBlock(columnValues.get(i), (DecimalType) type); + } + else { + throw new IllegalStateException("Unsupported type " + type); + } + } + + return new Page(blocks); + } + + public static Page createPageWithDictionaryBlocks(List types, String prefix, List> columnValues) + { + Block[] blocks = new Block[types.size()]; + for (int i = 0; i < blocks.length; i++) { + Type type = types.get(i); + if (type.equals(INTEGER)) { + blocks[i] = BlockUtil.createIntegerDictionaryBlock(columnValues.get(i)); + } + else if (type.equals(BIGINT)) { + blocks[i] = BlockUtil.createLongDictionaryBlock(columnValues.get(i)); + } + else if (type.equals(REAL)) { + blocks[i] = BlockUtil.createRealDictionaryBlock(columnValues.get(i)); + } + else if (type.equals(DOUBLE)) { + blocks[i] = 
BlockUtil.createDoubleDictionaryBlock(columnValues.get(i)); + } + else if (type instanceof VarcharType) { + blocks[i] = BlockUtil.createStringDictionaryBlock(prefix, columnValues.get(i), (VarcharType) type); + } + else if (type.equals(BOOLEAN)) { + blocks[i] = BlockUtil.createBooleanDictionaryBlock(columnValues.get(i)); + } + else if (type.equals(DATE)) { + blocks[i] = BlockUtil.createDateDictionaryBlock(columnValues.get(i)); + } + else if (type.equals(TIMESTAMP)) { + blocks[i] = BlockUtil.createTimestampDictionaryBlock(columnValues.get(i)); + } + else if (isShortDecimal(type)) { + blocks[i] = BlockUtil.createShortDecimalDictionaryBlock(columnValues.get(i), (DecimalType) type); + } + else if (isLongDecimal(type)) { + blocks[i] = BlockUtil.createLongDecimalDictionaryBlock(columnValues.get(i), (DecimalType) type); + } + else { + throw new IllegalStateException("Unsupported type " + type); + } + } + + return new Page(blocks); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestByteArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestByteArrayOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..4366a2cd8492a462bc91beb6eac3f8143a41ac0c --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestByteArrayOmniBlock.java @@ -0,0 +1,272 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.util.BloomFilter; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.AssertJUnit; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestByteArrayOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlockByBuilder(); + BooleanVec booleanVec = new BooleanVec(4); + booleanVec.set(0, false); + booleanVec.set(1, true); + booleanVec.set(2, false); + booleanVec.set(3, true); + ByteArrayOmniBlock byteArrayOmniBlock = new ByteArrayOmniBlock(4, booleanVec); + assertBlockEquals(BOOLEAN, byteArrayOmniBlock, baseBlock); + assertEquals(baseBlock.toString(), byteArrayOmniBlock.toString()); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + byteArrayOmniBlock.retainedBytesForEachPart((part, size) -> { + if (size == booleanVec.getCapacityInBytes()) { + 
isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + Block byteArrayOmniBlockRegion = byteArrayOmniBlock.getRegion(2, 2); + assertEquals(byteArrayOmniBlockRegion.getPositionCount(), 2); + + for (int i = 0; i < byteArrayOmniBlockRegion.getPositionCount(); i++) { + assertEquals(byteArrayOmniBlockRegion.get(i), byteArrayOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, (BooleanVec) baseBlock.getValues()); + + baseBlock.close(); + byteArrayOmniBlock.close(); + byteArrayOmniBlockRegion.close(); + actualBlock.close(); + } + + @Test + public void testInvalidInput() + { + byte[] bytes = {}; + byte[] values = {}; + assertThatThrownBy(() -> new ByteArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, 1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + assertThatThrownBy(() -> new ByteArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, -1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + assertThatThrownBy(() -> new ByteArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + byte[] values2len = new byte[6]; + assertThatThrownBy(() -> new ByteArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 0, 4, bytes, values2len)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + Block baseBlock = buildBlockByBuilder(); + BooleanVec booleanVec = (BooleanVec) baseBlock.getValues(); + byte[] bytes2array = {}; + assertThatThrownBy(() -> new ByteArrayOmniBlock(-1, 4, bytes2array, booleanVec)) + 
.isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + assertThatThrownBy(() -> new ByteArrayOmniBlock(1, -1, bytes2array, booleanVec)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + assertThatThrownBy(() -> new ByteArrayOmniBlock(1, 6, bytes2array, booleanVec)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + assertThatThrownBy(() -> new ByteArrayOmniBlock(1, 4, bytes2array, booleanVec)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + baseBlock.close(); + } + + @Test + public void testGet() + { + BooleanVec booleanVec = new BooleanVec(4); + booleanVec.set(0, false); + booleanVec.set(1, true); + booleanVec.set(2, false); + booleanVec.set(3, true); + Block byteArrayOmniBlock = new ByteArrayOmniBlock(4, booleanVec); + long expect = 2L; + long expectSizeBytes = 8L; + long expectStates = 1L; + boolean[] position = {true, true, true, true}; + assertEquals(byteArrayOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(byteArrayOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(byteArrayOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(byteArrayOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + byteArrayOmniBlock.close(); + } + + @Test + public void testCopyRegion() + { + BooleanVec booleanVec = new BooleanVec(4); + booleanVec.set(0, false); + booleanVec.set(1, true); + booleanVec.set(2, false); + booleanVec.set(3, true); + Block byteArrayOmniBlock = new ByteArrayOmniBlock(4, booleanVec); + Block copyRegionBlock = byteArrayOmniBlock.copyRegion(0, byteArrayOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, (BooleanVec) byteArrayOmniBlock.getValues()); + + Block copyNotEqualRegionBlock = byteArrayOmniBlock.copyRegion(0, 3); + 
assertBlockEquals(copyNotEqualRegionBlock, (BooleanVec) byteArrayOmniBlock.getValues()); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testCopyPosition() + { + BooleanVec booleanVec = new BooleanVec(4); + booleanVec.set(0, false); + booleanVec.set(1, true); + booleanVec.set(2, false); + booleanVec.set(3, true); + Block byteArrayOmniBlock = new ByteArrayOmniBlock(4, booleanVec); + + int[] positions = {0, 2, 3}; + Block copyPositionsBlock = byteArrayOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyPositionsBlock.getByte(i, 0), byteArrayOmniBlock.getByte(positions[i], 0)); + } + byteArrayOmniBlock.close(); + copyPositionsBlock.close(); + } + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = BOOLEAN.createBlockBuilder(null, 4); + blockBuilder.appendNull(); + BOOLEAN.writeBoolean(blockBuilder, false); + blockBuilder.appendNull(); + BOOLEAN.writeBoolean(blockBuilder, false); + Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + + ByteArrayOmniBlock nullByteArrayOmniBlock = new ByteArrayOmniBlock(4, (BooleanVec) block.getValues()); + // build block from vec + AssertJUnit.assertTrue(nullByteArrayOmniBlock.isNull(0)); + AssertJUnit.assertTrue(nullByteArrayOmniBlock.isNull(2)); + + nullByteArrayOmniBlock.close(); + } + + @Test + public void testFilter() + { + int count = 4; + int size = 4; + boolean[] valid = new boolean[count]; + Arrays.fill(valid, Boolean.TRUE); + ByteArrayOmniBlock block = getBlock(count); + Byte[] values = new Byte[block.getPositionCount()]; + + BloomFilter bf = getBf(size); + for (int i = 0; i < block.getPositionCount(); i++) { + values[i] = (block.getByte(i, 0)); + } + boolean[] actualValidPositions = block.filter(bf, valid); + assertEquals(actualValidPositions, valid); + + int[] positions = {0, 1, 2, 3}; + int positionCount = 4; + int[] matchedPosition = new int[4]; + int actualFilterPositions = 
block.filter(positions, positionCount, matchedPosition, (x) -> { + return true; + }); + assertEquals(actualFilterPositions, positionCount); + + block.close(); + } + + private Block buildBlockByBuilder() + { + BlockBuilder blockBuilder = BOOLEAN.createBlockBuilder(null, 4); + BOOLEAN.writeBoolean(blockBuilder, false); + BOOLEAN.writeBoolean(blockBuilder, true); + BOOLEAN.writeBoolean(blockBuilder, false); + BOOLEAN.writeBoolean(blockBuilder, true); + return OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + } + + private ByteArrayOmniBlock getBlock(int count) + { + BooleanVec booleanVec = new BooleanVec(count); + for (int i = 0; i < count; i++) { + if ((i & 1) == 1) { + booleanVec.set(i, true); + } + else { + booleanVec.set(i, false); + } + } + return new ByteArrayOmniBlock(count, booleanVec); + } + + private BloomFilter getBf(int size) + { + Random rnd = new Random(); + BloomFilter bf = new BloomFilter(size, 0.01); + for (int i = 0; i < 100; i++) { + bf.test(("value" + rnd.nextLong()).getBytes()); + } + return bf; + } + + private static void assertBlockEquals(Block actual, BooleanVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), (expected.get(position)) ? 
(byte) 1 : (byte) 0); + } + } + + private static void assertBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(type.getObjectValue(SESSION, actual, position), + type.getObjectValue(SESSION, expected, position)); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDictionaryOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDictionaryOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..13ce260949e723ea189617fa278ba813462f4b80 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDictionaryOmniBlock.java @@ -0,0 +1,286 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.block.DictionaryId; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.util.BloomFilter; +import nova.hetu.omniruntime.vector.DictionaryVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static nova.hetu.olk.tool.OperatorUtils.buildOffHeapBlock; +import static nova.hetu.olk.tool.OperatorUtils.buildOnHeapBlock; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestDictionaryOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testBasicFunc() + { + // build vec through vec + int[] ids = {0, 1, 2, 3}; + Block baseBlock = buildBlockByBuilder(); + DictionaryOmniBlock dictionaryOmniBlock = new DictionaryOmniBlock((Vec) baseBlock.getValues(), ids); + assertBlockEquals(VARCHAR, dictionaryOmniBlock, baseBlock); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + dictionaryOmniBlock.retainedBytesForEachPart((part, size) -> { + if (part == ids) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + DictionaryId dictionaryId = 
dictionaryOmniBlock.getDictionarySourceId(); + int[] positions = new int[11]; + DictionaryOmniBlock interceptBlock = (DictionaryOmniBlock) dictionaryOmniBlock.getPositions(positions, 0, 4); + assertEquals(interceptBlock.getDictionarySourceId(), dictionaryId); + + Block regionDicOmniBlock = dictionaryOmniBlock.getRegion(2, 2); + assertEquals(regionDicOmniBlock.getPositionCount(), 2); + for (int i = 0; i < regionDicOmniBlock.getPositionCount(); i++) { + assertEquals(regionDicOmniBlock.get(i), dictionaryOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, buildOnHeapBlock(dictionaryOmniBlock)); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, (VarcharVec) baseBlock.getValues()); + + baseBlock.close(); + dictionaryOmniBlock.close(); + regionDicOmniBlock.close(); + interceptBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlockByBuilder(); + int[] ids = {0, 1, 2, 3}; + Block dictionaryOmniBlock = new DictionaryOmniBlock((Vec) baseBlock.getValues(), ids); + Block copyRegionBlock = dictionaryOmniBlock.copyRegion(0, dictionaryOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, + (VarcharVec) ((DictionaryOmniBlock) dictionaryOmniBlock).getDictionary().getValues()); + + Block compactDicBlock = buildBlock2Compact(); + int[] ids1 = {0, 1, 2, 3}; + Block newBlock2Compact = new DictionaryOmniBlock((Vec) compactDicBlock.getValues(), ids1); + Block copyRegionBlock2Compact = newBlock2Compact.copyRegion(0, newBlock2Compact.getPositionCount()); + assertBlockEquals(copyRegionBlock2Compact, + (VarcharVec) ((DictionaryOmniBlock) newBlock2Compact).getDictionary().getValues()); + + baseBlock.close(); + dictionaryOmniBlock.close(); + copyRegionBlock.close(); + compactDicBlock.close(); + newBlock2Compact.close(); + copyRegionBlock2Compact.close(); + } + + @Test 
+ public void testCopyPosition() + { + Block baseBlock = buildBlockByBuilder(); + int[] ids = {0, 1, 2, 3}; + Block dictionaryOmniBlock = new DictionaryOmniBlock((Vec) baseBlock.getValues(), ids); + int[] positions = {0, 2, 3}; + Block copyRegionBlock = dictionaryOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyRegionBlock.getString(i, 0, 0), dictionaryOmniBlock.getString(positions[i], 0, 0)); + } + + baseBlock.close(); + dictionaryOmniBlock.close(); + copyRegionBlock.close(); + } + + @Test + public void testFilter() + { + int count = 4; + int size = 1000; + boolean[] valid = new boolean[count]; + Arrays.fill(valid, Boolean.TRUE); + DictionaryOmniBlock block = getBlock(count); + String[] values = new String[block.getPositionCount()]; + + BloomFilter bf = getBf(size); + for (int i = 0; i < block.getPositionCount(); i++) { + values[i] = block.getString(i, 0, 0); + } + + boolean[] actualValidPositions = block.filter(bf, valid); + assertEquals(actualValidPositions, valid); + + int[] positions = {0, 1, 2, 3}; + int positionCount = 4; + int[] matchedPosition = new int[4]; + int actualFilterPositions = block.filter(positions, positionCount, matchedPosition, (x) -> { + return true; + }); + assertEquals(actualFilterPositions, positionCount); + block.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlockByBuilder(); + int[] ids = {0, 1, 2, 3}; + DictionaryOmniBlock dictionaryOmniBlock = new DictionaryOmniBlock((Vec) baseBlock.getValues(), ids); + long expect = 14L; + long expectSizeBytes = 55L; + long expectStates = 5L; + boolean[] position = {true, true, true, true}; + long expect2LogicalSizeInBytes = 39L; + String expectStr = "DictionaryOmniBlock{positionCount=4}"; + assertEquals(dictionaryOmniBlock.toString(), expectStr); + assertEquals(dictionaryOmniBlock.getLogicalSizeInBytes(), expect2LogicalSizeInBytes); + assertEquals(dictionaryOmniBlock.getRegionSizeInBytes(0, 1), expect); + 
assertEquals(dictionaryOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(dictionaryOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(dictionaryOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + Block loadedOmniBlock = dictionaryOmniBlock.getLoadedBlock(); + assertBlockEquals(VARCHAR, dictionaryOmniBlock, loadedOmniBlock); + + baseBlock.close(); + loadedOmniBlock.close(); + } + + @Test + public void testInvalidInput() + { + Block baseBlock = buildBlockByBuilder(); + int[] ids = {0, 1, 2, 3}; + DictionaryOmniBlock throwDicOmniBlock = new DictionaryOmniBlock((Vec) baseBlock.getValues(), ids); + assertThatThrownBy(() -> new DictionaryOmniBlock(1, -1, (Vec) baseBlock.getValues(), ids, false, + throwDicOmniBlock.getDictionarySourceId())).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("positionCount is negative"); + assertThatThrownBy(() -> new DictionaryOmniBlock(1, 4, (Vec) baseBlock.getValues(), ids, false, + throwDicOmniBlock.getDictionarySourceId())).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("ids length is less than positionCount"); + assertThatThrownBy(() -> new DictionaryOmniBlock(1, -1, (DictionaryVec) throwDicOmniBlock.getValues(), ids, + baseBlock, false, throwDicOmniBlock.getDictionarySourceId())) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("positionCount is negative"); + assertThatThrownBy(() -> new DictionaryOmniBlock(1, 4, (DictionaryVec) throwDicOmniBlock.getValues(), ids, + baseBlock, false, throwDicOmniBlock.getDictionarySourceId())) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("ids length is less than positionCount"); + + baseBlock.close(); + throwDicOmniBlock.close(); + } + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "alice"); + 
blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "bob"); + Block block = buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + + int[] ids = {0, 1, 2, 3}; + DictionaryOmniBlock nullDicOmniBlock = new DictionaryOmniBlock((Vec) block.getValues(), ids); + // build block from vec + assertTrue(nullDicOmniBlock.isNull(0)); + assertTrue(nullDicOmniBlock.isNull(2)); + + nullDicOmniBlock.close(); + block.close(); + } + + private Block buildBlockByBuilder() + { + BlockBuilder dictionaryOmniBuilder = VARCHAR.createBlockBuilder(null, 4); + VARCHAR.writeString(dictionaryOmniBuilder, "alice"); + VARCHAR.writeString(dictionaryOmniBuilder, "bob"); + VARCHAR.writeString(dictionaryOmniBuilder, "charlie"); + VARCHAR.writeString(dictionaryOmniBuilder, "dave"); + return buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, dictionaryOmniBuilder.build()); + } + + private Block buildBlock2Compact() + { + BlockBuilder dictionaryOmniBuilder = VARCHAR.createBlockBuilder(null, 5); + VARCHAR.writeString(dictionaryOmniBuilder, "alice"); + VARCHAR.writeString(dictionaryOmniBuilder, "bob"); + VARCHAR.writeString(dictionaryOmniBuilder, "charlie"); + VARCHAR.writeString(dictionaryOmniBuilder, "dave"); + VARCHAR.writeString(dictionaryOmniBuilder, "dave"); + return buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, dictionaryOmniBuilder.build()); + } + + private DictionaryOmniBlock getBlock(int count) + { + Block baseBlock = buildBlockByBuilder(); + int[] ids = {0, 1, 2, 3}; + Vec dictionary = (Vec) baseBlock.getValues(); + DictionaryOmniBlock dictionaryOmniBlock = new DictionaryOmniBlock(count, dictionary, ids); + baseBlock.close(); + return dictionaryOmniBlock; + } + + private BloomFilter getBf(int size) + { + Random rnd = new Random(); + BloomFilter bf = new BloomFilter(size, 0.01); + for (int i = 0; i < 100; i++) { + bf.test(("value" + rnd.nextLong()).getBytes()); + } + return bf; + } + + private static void assertBlockEquals(Block actual, 
VarcharVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(new String((byte[]) actual.get(position)), new String(expected.get(position))); + } + } + + private static void assertBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(type.getObjectValue(SESSION, actual, position), + type.getObjectValue(SESSION, expected, position)); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDoubleArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDoubleArrayOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..b4048940a7fc25cf5f09057541feddd56efce987 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestDoubleArrayOmniBlock.java @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestDoubleArrayOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test(enabled = false) + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + DOUBLE.writeDouble(blockBuilder, 42.33); + blockBuilder.appendNull(); + DOUBLE.writeDouble(blockBuilder, Double.MAX_VALUE); + Block onHeapBlock = blockBuilder.build(); + + Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, onHeapBlock); + + assertTrue(block.isNull(0)); + assertEquals(DOUBLE.getDouble(block, 1), 42.33); + assertTrue(block.isNull(2)); + assertEquals(DOUBLE.getDouble(block, 3), Double.MAX_VALUE); + + // build block from vec + Block block1 = new DoubleArrayOmniBlock(4, (DoubleVec) block.getValues()); + assertTrue(block1.isNull(0)); + assertEquals(DOUBLE.getDouble(block1, 1), 42.33); + assertTrue(block1.isNull(2)); + assertEquals(DOUBLE.getDouble(block1, 3), 
Double.MAX_VALUE); + block.close(); + } + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlockByVec(); + DoubleArrayOmniBlock doubleArrayOmniBlock = new DoubleArrayOmniBlock(baseBlock.getPositionCount(), + (DoubleVec) baseBlock.getValues()); + assertBlockEquals(DOUBLE, baseBlock, doubleArrayOmniBlock); + assertEquals(baseBlock.toString(), doubleArrayOmniBlock.toString()); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + doubleArrayOmniBlock.retainedBytesForEachPart((part, size) -> { + if (size == ((DoubleVec) baseBlock.getValues()).getCapacityInBytes()) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + Block regionOmniBlock = doubleArrayOmniBlock.getRegion(2, 2); + assertEquals(regionOmniBlock.getPositionCount(), 2); + for (int i = 0; i < regionOmniBlock.getPositionCount(); i++) { + assertEquals(regionOmniBlock.get(i), doubleArrayOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + Block serBlock = OperatorUtils.buildOnHeapBlock(baseBlock); + blockEncodingSerde.writeBlock(sliceOutput, serBlock); + Block deSerBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + Block actualBlock = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, deSerBlock, + deSerBlock.getClass().getSimpleName(), deSerBlock.getPositionCount(), DOUBLE); + assertBlockEquals(actualBlock, (DoubleVec) baseBlock.getValues()); + + doubleArrayOmniBlock.close(); + regionOmniBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlockByVec(); + Block doubleArrayOmniBlock = new DoubleArrayOmniBlock(baseBlock.getPositionCount(), + (DoubleVec) baseBlock.getValues()); + Block copyRegionBlock = doubleArrayOmniBlock.copyRegion(0, doubleArrayOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, (DoubleVec) doubleArrayOmniBlock.getValues()); + + Block copyNotEqualRegionBlock = 
doubleArrayOmniBlock.copyRegion(0, 3); + assertBlockEquals(copyNotEqualRegionBlock, (DoubleVec) doubleArrayOmniBlock.getValues()); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testCopyPosition() + { + Block baseBlock = buildBlockByVec(); + Block doubleArrayOmniBlock = new DoubleArrayOmniBlock(baseBlock.getPositionCount(), + (DoubleVec) baseBlock.getValues()); + + int[] positions = {0, 2, 3}; + Block copyRegionBlock = doubleArrayOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyRegionBlock.getDouble(i, 0), doubleArrayOmniBlock.getDouble(positions[i], 0)); + } + + doubleArrayOmniBlock.close(); + copyRegionBlock.close(); + } + + @Test + public void testInvalidInput() + { + byte[] bytes = {}; + double[] values = {}; + assertThatThrownBy(() -> new DoubleArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, 1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new DoubleArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, -1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new DoubleArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + double[] values2len = new double[6]; + assertThatThrownBy(() -> new DoubleArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values2len)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + Block baseBlock = buildBlockByVec(); + DoubleVec expected = (DoubleVec) baseBlock.getValues(); + byte[] bytes2array = {}; + assertThatThrownBy(() -> new DoubleArrayOmniBlock(-1, 4, bytes2array, expected)) + 
.isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new DoubleArrayOmniBlock(1, -1, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new DoubleArrayOmniBlock(1, 6, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + assertThatThrownBy(() -> new DoubleArrayOmniBlock(1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + baseBlock.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlockByVec(); + Block doubleArrayOmniBlock = new DoubleArrayOmniBlock(baseBlock.getPositionCount(), + (DoubleVec) baseBlock.getValues()); + long expect = 9; + long expectSizeBytes = 36; + long expectStates = 8; + boolean[] position = {true, true, true, true}; + assertEquals(doubleArrayOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(doubleArrayOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(doubleArrayOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(doubleArrayOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + doubleArrayOmniBlock.close(); + } + + // Builds a 4-element off-heap double block; each position must get a distinct + // index (0-3) — writing them all to index 0 would leave positions 1-3 unset. + private Block buildBlockByVec() + { + DoubleVec doubleVec = new DoubleVec(4); + doubleVec.set(0, 42.33); + doubleVec.set(1, 43.34); + doubleVec.set(2, 44.35); + doubleVec.set(3, 45.36); + + return new DoubleArrayOmniBlock(4, doubleVec); + } + + private static void assertBlockEquals(Block actual, DoubleVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), new Double(expected.get(position))); + } + } + + private static void assertBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; 
position < actual.getPositionCount(); position++) { + assertEquals(type.getObjectValue(SESSION, actual, position), + type.getObjectValue(SESSION, expected, position)); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestInt128ArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestInt128ArrayOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..79a1739774920ccf03661feb7efe3cae7b6a5e3d --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestInt128ArrayOmniBlock.java @@ -0,0 +1,202 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockEncodingSerde; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestInt128ArrayOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testMultipleValuesWithNull() + { + int positionCount = 4; + long[] values = {0L, 0L, 0L, 42L, 0L, 0L, Long.MAX_VALUE, Long.MAX_VALUE}; + byte[] valueIsNull = {Vec.NULL, Vec.NOT_NULL, Vec.NULL, Vec.NOT_NULL}; + Int128ArrayOmniBlock block = new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, positionCount, + Optional.of(valueIsNull), values); + + assertTrue(block.isNull(0)); + assertEquals(block.get(1), new long[]{0L, 42L}); + assertTrue(block.isNull(2)); + assertEquals(block.get(3), new long[]{Long.MAX_VALUE, Long.MAX_VALUE}); + + // build block from vec + Block block1 = new Int128ArrayOmniBlock(4, (Decimal128Vec) block.getValues()); + assertTrue(block1.isNull(0)); + assertEquals(block1.get(1), new long[]{0L, 42L}); + assertTrue(block1.isNull(2)); + // verify the vec-built block, not the original (copy-paste fix: block -> block1) + assertEquals(block1.get(3), new long[]{Long.MAX_VALUE, Long.MAX_VALUE}); + block.close(); + } + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlock(); + Int128ArrayOmniBlock int128ArrayOmniBlock = new 
Int128ArrayOmniBlock(baseBlock.getPositionCount(), + (Decimal128Vec) baseBlock.getValues()); + assertBlockEquals(int128ArrayOmniBlock, baseBlock); + String expect = "Int128ArrayOmniBlock{positionCount=4}"; + assertEquals(int128ArrayOmniBlock.toString(), expect); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + int128ArrayOmniBlock.retainedBytesForEachPart((part, size) -> { + if (part == baseBlock.getValues()) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + Block regionInt128ArrayOmniBlock = int128ArrayOmniBlock.getRegion(2, 2); + assertEquals(regionInt128ArrayOmniBlock.getPositionCount(), 2); + for (int i = 0; i < regionInt128ArrayOmniBlock.getPositionCount(); i++) { + assertEquals(regionInt128ArrayOmniBlock.get(i), int128ArrayOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, baseBlock); + + int128ArrayOmniBlock.close(); + regionInt128ArrayOmniBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlock(); + Block int128ArrayOmniBlock = new Int128ArrayOmniBlock(baseBlock.getPositionCount(), + (Decimal128Vec) baseBlock.getValues()); + Block copyRegionBlock = int128ArrayOmniBlock.copyRegion(0, int128ArrayOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, int128ArrayOmniBlock); + + Block copyNotEqualRegionBlock = int128ArrayOmniBlock.copyRegion(0, 3); + assertBlockEquals(copyNotEqualRegionBlock, int128ArrayOmniBlock); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testCopyPosition() + { + Block baseBlock = buildBlock(); + Block int128ArrayOmniBlock = new Int128ArrayOmniBlock(baseBlock.getPositionCount(), + (Decimal128Vec) baseBlock.getValues()); + + int[] positions = {0, 1, 2, 3}; + Block copyRegionBlock = 
int128ArrayOmniBlock.copyPositions(positions, 0, 4); + assertBlockEquals(copyRegionBlock, int128ArrayOmniBlock); + + int128ArrayOmniBlock.close(); + copyRegionBlock.close(); + } + + @Test + public void testInvalidInput() + { + byte[] bytes = {}; + long[] values = {}; + assertThatThrownBy(() -> new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, 1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionOffset is negative"); + + assertThatThrownBy(() -> new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, -1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + long[] values2len = new long[6]; + assertThatThrownBy(() -> new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 1, bytes, values2len)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + Block baseBlock = buildBlock(); + Decimal128Vec expected = (Decimal128Vec) baseBlock.getValues(); + byte[] bytes2array = {}; + assertThatThrownBy(() -> new Int128ArrayOmniBlock(-1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionOffset is negative"); + + assertThatThrownBy(() -> new Int128ArrayOmniBlock(1, -1, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new Int128ArrayOmniBlock(1, 6, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + assertThatThrownBy(() -> new Int128ArrayOmniBlock(1, 4, bytes2array, expected)) + 
.isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + baseBlock.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlock(); + Block int128ArrayOmniBlock = new Int128ArrayOmniBlock(baseBlock.getPositionCount(), + (Decimal128Vec) baseBlock.getValues()); + long expect = 17; + long expectSizeBytes = 68; + long expectStates = 0; + boolean[] position = {true, true, true, true}; + assertEquals(int128ArrayOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(int128ArrayOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(int128ArrayOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(int128ArrayOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + int128ArrayOmniBlock.close(); + } + + private Block buildBlock() + { + int positionCount = 4; + long[] values = {0L, 0L, 0L, 42L, 0L, 0L, Long.MAX_VALUE, Long.MAX_VALUE}; + byte[] valueIsNull = {Vec.NULL, Vec.NOT_NULL, Vec.NULL, Vec.NOT_NULL}; + Int128ArrayOmniBlock block = new Int128ArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, positionCount, + Optional.of(valueIsNull), values); + return block; + } + + private static void assertBlockEquals(Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), expected.get(position)); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestIntArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestIntArrayOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..3ef9377ae81f6c7811d9a0ae2a75466ff64a3f7f --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestIntArrayOmniBlock.java @@ -0,0 +1,213 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestIntArrayOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + INTEGER.writeLong(blockBuilder, 42); + blockBuilder.appendNull(); + INTEGER.writeLong(blockBuilder, Integer.MAX_VALUE); + Block onHeapBlock = blockBuilder.build(); + Block block 
= OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, onHeapBlock); + + assertTrue(block.isNull(0)); + assertEquals(INTEGER.getLong(block, 1), 42); + assertTrue(block.isNull(2)); + assertEquals(INTEGER.getLong(block, 3), Integer.MAX_VALUE); + + // build block from vec + Block nullOmniBlock = new IntArrayOmniBlock(4, (IntVec) block.getValues()); + assertTrue(nullOmniBlock.isNull(0)); + assertEquals(INTEGER.getLong(nullOmniBlock, 1), 42); + assertTrue(nullOmniBlock.isNull(2)); + assertEquals(INTEGER.getLong(nullOmniBlock, 3), Integer.MAX_VALUE); + + block.close(); + } + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlockByBuilder(); + IntArrayOmniBlock intArrayOmniBlock = new IntArrayOmniBlock(baseBlock.getPositionCount(), + (IntVec) baseBlock.getValues()); + assertBlockEquals(INTEGER, intArrayOmniBlock, baseBlock); + assertEquals(baseBlock.toString(), intArrayOmniBlock.toString()); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + intArrayOmniBlock.retainedBytesForEachPart((part, size) -> { + if (size == ((IntVec) baseBlock.getValues()).getCapacityInBytes()) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + Block regionOmniBlock = intArrayOmniBlock.getRegion(2, 2); + assertEquals(regionOmniBlock.getPositionCount(), 2); + for (int i = 0; i < regionOmniBlock.getPositionCount(); i++) { + assertEquals(regionOmniBlock.get(i), intArrayOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, (IntVec) baseBlock.getValues()); + + intArrayOmniBlock.close(); + regionOmniBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyPosition() + { + Block baseBlock = buildBlockByBuilder(); + Block intArrayOmniBlock = new 
IntArrayOmniBlock(baseBlock.getPositionCount(), (IntVec) baseBlock.getValues()); + + int[] positions = {0, 2, 3}; + Block copyPositionsBlock = intArrayOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyPositionsBlock.getInt(i, 0), intArrayOmniBlock.getInt(positions[i], 0)); + } + intArrayOmniBlock.close(); + copyPositionsBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlockByBuilder(); + Block intArrayOmniBlock = new IntArrayOmniBlock(baseBlock.getPositionCount(), (IntVec) baseBlock.getValues()); + Block copyRegionBlock = intArrayOmniBlock.copyRegion(0, intArrayOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, (IntVec) intArrayOmniBlock.getValues()); + + Block copyNotEqualRegionBlock = intArrayOmniBlock.copyRegion(0, 3); + assertBlockEquals(copyNotEqualRegionBlock, (IntVec) intArrayOmniBlock.getValues()); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testInvalidInput() + { + byte[] bytes = {}; + int[] values = {}; + + assertThatThrownBy(() -> new IntArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, 1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new IntArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, -1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new IntArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + int[] values2len = new int[6]; + assertThatThrownBy(() -> new IntArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values2len)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + Block baseBlock = 
buildBlockByBuilder(); + IntVec expected = (IntVec) baseBlock.getValues(); + byte[] bytes2array = {}; + assertThatThrownBy(() -> new IntArrayOmniBlock(-1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new IntArrayOmniBlock(1, -1, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new IntArrayOmniBlock(1, 6, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + assertThatThrownBy(() -> new IntArrayOmniBlock(1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + baseBlock.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlockByBuilder(); + Block intArrayOmniBlock = new IntArrayOmniBlock(baseBlock.getPositionCount(), (IntVec) baseBlock.getValues()); + long expect = 5; + long expectSizeBytes = 20; + long expectStates = 4; + boolean[] position = {true, true, true, true}; + assertEquals(intArrayOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(intArrayOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(intArrayOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(intArrayOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + intArrayOmniBlock.close(); + } + + private Block buildBlockByBuilder() + { + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 4); + INTEGER.writeLong(blockBuilder, 42); + INTEGER.writeLong(blockBuilder, 43); + INTEGER.writeLong(blockBuilder, 44); + INTEGER.writeLong(blockBuilder, 45); + return OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + } + + private static void assertBlockEquals(Block actual, IntVec expected) 
+    {
+        for (int position = 0; position < actual.getPositionCount(); position++) {
+            // valueOf (or plain autoboxing) instead of the deprecated Integer(int) constructor
+            assertEquals((Integer) actual.get(position), Integer.valueOf(expected.get(position)));
+        }
+    }
+
+    private static void assertBlockEquals(Type type, Block actual, Block expected)
+    {
+        for (int position = 0; position < actual.getPositionCount(); position++) {
+            assertEquals(type.getObjectValue(SESSION, actual, position),
+                    type.getObjectValue(SESSION, expected, position));
+        }
+    }
+}
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestLongArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestLongArrayOmniBlock.java
new file mode 100644
index 0000000000000000000000000000000000000000..518332fc610a8be17677a8073b6cc906399a6e39
--- /dev/null
+++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestLongArrayOmniBlock.java
@@ -0,0 +1,216 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestLongArrayOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + BIGINT.writeLong(blockBuilder, 42); + blockBuilder.appendNull(); + BIGINT.writeLong(blockBuilder, Long.MAX_VALUE); + Block onHeapBlock = blockBuilder.build(); + Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, onHeapBlock); + + assertTrue(block.isNull(0)); + assertEquals(BIGINT.getLong(block, 1), 42L); + assertTrue(block.isNull(2)); + assertEquals(BIGINT.getLong(block, 3), Long.MAX_VALUE); + + // build block from vec + Block longArrayOmniBlock = new LongArrayOmniBlock(4, (LongVec) block.getValues()); + assertTrue(longArrayOmniBlock.isNull(0)); + assertEquals(BIGINT.getLong(longArrayOmniBlock, 1), 42L); + assertTrue(longArrayOmniBlock.isNull(2)); + assertEquals(BIGINT.getLong(longArrayOmniBlock, 
3), Long.MAX_VALUE); + + block.close(); + } + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlockByBuilder(); + LongArrayOmniBlock longArrayOmniBlock = new LongArrayOmniBlock(baseBlock.getPositionCount(), + (LongVec) baseBlock.getValues()); + assertBlockEquals(BIGINT, longArrayOmniBlock, baseBlock); + assertEquals(baseBlock.toString(), longArrayOmniBlock.toString()); + + Block regionOmniBlock = longArrayOmniBlock.getRegion(2, 2); + assertEquals(regionOmniBlock.getPositionCount(), 2); + for (int i = 0; i < regionOmniBlock.getPositionCount(); i++) { + assertEquals(regionOmniBlock.get(i), longArrayOmniBlock.get(i + 2)); + } + + AtomicBoolean isIdentical = new AtomicBoolean(false); + longArrayOmniBlock.retainedBytesForEachPart((part, size) -> { + if (size == ((LongVec) baseBlock.getValues()).getCapacityInBytes()) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, (LongVec) baseBlock.getValues()); + + longArrayOmniBlock.close(); + regionOmniBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlockByBuilder(); + Block longArrayOmniBlock = new LongArrayOmniBlock(baseBlock.getPositionCount(), + (LongVec) baseBlock.getValues()); + Block copyRegionBlock = longArrayOmniBlock.copyRegion(0, longArrayOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, (LongVec) longArrayOmniBlock.getValues()); + + Block copyNotEqualRegionBlock = longArrayOmniBlock.copyRegion(0, 3); + assertBlockEquals(copyNotEqualRegionBlock, (LongVec) longArrayOmniBlock.getValues()); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testCopyPosition() + { + Block baseBlock = buildBlockByBuilder(); + Block 
longArrayOmniBlock = new LongArrayOmniBlock(baseBlock.getPositionCount(), (LongVec) baseBlock.getValues()); + + int[] positions = {0, 2, 3}; + Block copyRegionBlock = longArrayOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyRegionBlock.getLong(i, 0), longArrayOmniBlock.getLong(positions[i], 0)); + } + longArrayOmniBlock.close(); + copyRegionBlock.close(); + } + + @Test + public void testInvalidInput() + { + Block baseBlock = buildBlockByBuilder(); + byte[] bytes = {}; + long[] values = {}; + + assertThatThrownBy(() -> new LongArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, 1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new LongArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, -1, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new LongArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + long[] values2len = new long[6]; + assertThatThrownBy(() -> new LongArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, bytes, values2len)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + LongVec expected = (LongVec) baseBlock.getValues(); + byte[] bytes2array = {}; + assertThatThrownBy(() -> new LongArrayOmniBlock(-1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new LongArrayOmniBlock(1, -1, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new LongArrayOmniBlock(1, 6, bytes2array, expected)) + 
.isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("values length is less than positionCount"); + + assertThatThrownBy(() -> new LongArrayOmniBlock(1, 4, bytes2array, expected)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("isNull length is less than positionCount"); + + baseBlock.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlockByBuilder(); + Block longArrayOmniBlock = new LongArrayOmniBlock(baseBlock.getPositionCount(), + (LongVec) baseBlock.getValues()); + long expect = 9; + long expectSizeBytes = 36; + long expectStates = 8; + boolean[] position = {true, true, true, true}; + assertEquals(longArrayOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(longArrayOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(longArrayOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(longArrayOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + longArrayOmniBlock.close(); + } + + private Block buildBlockByBuilder() + { + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 4); + BIGINT.writeLong(blockBuilder, 42); + BIGINT.writeLong(blockBuilder, 43); + BIGINT.writeLong(blockBuilder, 44); + BIGINT.writeLong(blockBuilder, 45); + return OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + } + + private static void assertBlockEquals(Block actual, LongVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals((Long) actual.get(position), new Long(expected.get(position))); + } + } + + private static void assertBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(type.getObjectValue(SESSION, actual, position), + type.getObjectValue(SESSION, expected, position)); + } + } +} diff --git 
a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestRowOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestRowOmniBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..aa3798c4e9b7357da9c1b4c8ade89b00f2fd4e9a --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestRowOmniBlock.java @@ -0,0 +1,151 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.block.RowBlock; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.airlift.slice.SizeOf.sizeOf; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestRowOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testBasicFunc() + { + // build vec through vec + RowOmniBlock baseBlock = (RowOmniBlock) buildBlockByBuilder(); + RowOmniBlock loadedBlock = (RowOmniBlock) baseBlock.getLoadedBlock(); + Block varcharBlock = baseBlock.getRawFieldBlocks()[0]; + Block varcharBlock2 = loadedBlock.getRawFieldBlocks()[0]; + assertBlockEquals(VARCHAR, varcharBlock, varcharBlock2); + + int[] fieldBlockOffsets = {0, 1}; + long sizes = sizeOf(fieldBlockOffsets); + AtomicBoolean isIdentical = new AtomicBoolean(false); + loadedBlock.retainedBytesForEachPart((part, size) -> { + if (size.equals(sizes)) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + Block rawFieldBlock = 
loadedBlock.getRawFieldBlocks()[0].getRegion(2, 2); + assertEquals(rawFieldBlock.getPositionCount(), 2); + for (int i = 0; i < rawFieldBlock.getPositionCount(); i++) { + assertEquals(rawFieldBlock.get(i), varcharBlock2.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + RowBlock actualBlock = (RowBlock) blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock.getRawFieldBlocks()[0], + (VarcharVec) baseBlock.getRawFieldBlocks()[0].getValues()); + + loadedBlock.close(); + rawFieldBlock.close(); + actualBlock.close(); + } + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "alice"); + blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "bob"); + Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + + assertTrue(block.isNull(0)); + assertEquals(VARCHAR.getObjectValue(SESSION, block, 1), "alice"); + assertTrue(block.isNull(2)); + assertEquals(VARCHAR.getObjectValue(SESSION, block, 3), "bob"); + + // build block from vec + Block block1 = new VariableWidthOmniBlock(4, (VarcharVec) block.getValues()); + assertTrue(block1.isNull(0)); + assertEquals(VARCHAR.getObjectValue(SESSION, block1, 1), "alice"); + assertTrue(block1.isNull(2)); + assertEquals(VARCHAR.getObjectValue(SESSION, block1, 3), "bob"); + block.close(); + } + + @Test + public void testInvalidInput() + { + byte[] rowIsNull = {}; + int[] fieldBlockOffsets = {}; + Block[] fieldBlocks = {}; + assertThatThrownBy(() -> new RowOmniBlock(-1, 4, rowIsNull, fieldBlockOffsets, fieldBlocks, DataType.INVALID)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("Number of fields in RowBlock must be positive"); + } + + private Block buildBlockByBuilder() + { + BlockBuilder 
rowBlockBuilder = VARCHAR.createBlockBuilder(null, 4);
+        VARCHAR.writeString(rowBlockBuilder, "alice");
+        VARCHAR.writeString(rowBlockBuilder, "bob");
+        VARCHAR.writeString(rowBlockBuilder, "charlie");
+        VARCHAR.writeString(rowBlockBuilder, "dave");
+        Block varCharBlock = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR,
+                rowBlockBuilder.build());
+
+        byte[] rowIsNull = new byte[1];
+        int[] fieldBlockOffsets = {0, 1};
+        Block[] blocks = new Block[1];
+        blocks[0] = varCharBlock;
+
+        return new RowOmniBlock(0, 1, rowIsNull, fieldBlockOffsets, blocks, VarcharDataType.VARCHAR);
+    }
+
+    private static void assertBlockEquals(Block actual, VarcharVec expected)
+    {
+        for (int position = 0; position < actual.getPositionCount(); position++) {
+            assertEquals(new String((byte[]) actual.get(position)), new String(expected.get(position))); // NOTE(review): new String(byte[]) uses the platform default charset — presumably UTF-8 is intended; consider StandardCharsets.UTF_8 and confirm the vec's encoding
+        }
+    }
+
+    private static void assertBlockEquals(Type type, Block actual, Block expected)
+    {
+        for (int position = 0; position < actual.getPositionCount(); position++) {
+            assertEquals(type.getObjectValue(SESSION, actual, position),
+                    type.getObjectValue(SESSION, expected, position));
+        }
+    }
+}
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestVariableWidthOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestVariableWidthOmniBlock.java
new file mode 100644
index 0000000000000000000000000000000000000000..4296a8c8e52404c40f9f6dba78d339d375b353e3
--- /dev/null
+++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/block/TestVariableWidthOmniBlock.java
@@ -0,0 +1,332 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.block; + +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.prestosql.metadata.InternalBlockEncodingSerde; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.util.BloomFilter; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.TestingSession.SESSION; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestVariableWidthOmniBlock +{ + private final BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde( + createTestMetadataManager().getFunctionAndTypeManager()); + + @Test + public void testBasicFunc() + { + // build vec through vec + Block baseBlock = buildBlockByBuilder(); + VariableWidthOmniBlock variableWidthOmniBlock = new VariableWidthOmniBlock(baseBlock.getPositionCount(), + 
(VarcharVec) baseBlock.getValues()); + assertBlockEquals(VARCHAR, variableWidthOmniBlock, baseBlock); + assertEquals(baseBlock.toString(), variableWidthOmniBlock.toString()); + + AtomicBoolean isIdentical = new AtomicBoolean(false); + variableWidthOmniBlock.retainedBytesForEachPart((part, size) -> { + if (size == ((VarcharVec) baseBlock.getValues()).getCapacityInBytes()) { + isIdentical.set(true); + } + }); + assertTrue(isIdentical.get()); + + variableWidthOmniBlock.setClosable(true); + + Block regionOmniBlock = variableWidthOmniBlock.getRegion(2, 2); + assertEquals(regionOmniBlock.getPositionCount(), 2); + + for (int i = 0; i < regionOmniBlock.getPositionCount(); i++) { + assertEquals(regionOmniBlock.get(i), variableWidthOmniBlock.get(i + 2)); + } + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); + blockEncodingSerde.writeBlock(sliceOutput, baseBlock); + Block actualBlock = blockEncodingSerde.readBlock(sliceOutput.slice().getInput()); + assertBlockEquals(actualBlock, (VarcharVec) baseBlock.getValues()); + + variableWidthOmniBlock.close(); + regionOmniBlock.close(); + actualBlock.close(); + } + + @Test + public void testCopyRegion() + { + Block baseBlock = buildBlockByBuilder(); + Block variableWidthOmniBlock = new VariableWidthOmniBlock(baseBlock.getPositionCount(), + (VarcharVec) baseBlock.getValues()); + Block copyRegionBlock = variableWidthOmniBlock.copyRegion(0, variableWidthOmniBlock.getPositionCount()); + assertBlockEquals(copyRegionBlock, (VarcharVec) variableWidthOmniBlock.getValues()); + + Block copyNotEqualRegionBlock = variableWidthOmniBlock.copyRegion(0, 3); + assertBlockEquals(copyNotEqualRegionBlock, (VarcharVec) variableWidthOmniBlock.getValues()); + + copyNotEqualRegionBlock.close(); + } + + @Test + public void testCopyPosition() + { + Block baseBlock = buildBlockByBuilder(); + Block variableWidthOmniBlock = new VariableWidthOmniBlock(baseBlock.getPositionCount(), + (VarcharVec) baseBlock.getValues()); + + int[] positions = {0, 2, 
3}; + Block copyPositionsBlock = variableWidthOmniBlock.copyPositions(positions, 0, 3); + for (int i = 0; i < 3; i++) { + assertEquals(copyPositionsBlock.getString(i, 0, 0), variableWidthOmniBlock.getString(positions[i], 0, 0)); + } + variableWidthOmniBlock.close(); + copyPositionsBlock.close(); + } + + @Test + public void testVarcharVecWithLastValueIsNull() + { + int position = 5; + String[] strs = new String[]{"alice", "bob", "charlie"}; + StringBuilder builder = new StringBuilder(); + for (String data : strs) { + builder.append(data); + } + int[] offset = new int[]{0, 5, 8, 15}; + VarcharVec values = new VarcharVec(1024, position); + values.put(0, builder.toString().getBytes(StandardCharsets.UTF_8), 0, offset, 0, 3); + values.setNull(3); + values.setNull(4); + VariableWidthOmniBlock block = new VariableWidthOmniBlock(position, values); + int totalLen = 0; + for (int i = 0; i < position; i++) { + totalLen += block.getSliceLength(i); + } + assertEquals(totalLen, 15); + totalLen = 0; + VariableWidthOmniBlock variableWidthOmniBlock = new VariableWidthOmniBlock(3, values.slice(2, 5)); + for (int i = 0; i < 3; i++) { + totalLen += variableWidthOmniBlock.getSliceLength(i); + } + assertEquals(totalLen, 7); + + block.close(); + variableWidthOmniBlock.close(); + } + + @Test + public void testFilter() + { + int count = 1024; + int size = 1000; + boolean[] valid = new boolean[count]; + Arrays.fill(valid, Boolean.TRUE); + VariableWidthOmniBlock block = getBlock(count); + String[] values = new String[block.getPositionCount()]; + + BloomFilter bf = getBf(size); + for (int i = 0; i < block.getPositionCount(); i++) { + values[i] = block.getString(i, 0, 0); + } + + boolean[] actualValidPositions = block.filter(bf, valid); + assertEquals(actualValidPositions, valid); + + int[] positions = {0, 1, 2, 3}; + int positionCount = 4; + int[] matchedPosition = new int[4]; + int actualFilterPositions = block.filter(positions, positionCount, matchedPosition, (x) -> { + return true; + }); + 
assertEquals(actualFilterPositions, positionCount); + block.close(); + } + + @Test + public void testMultipleValuesWithNull() + { + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 10); + blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "alice"); + blockBuilder.appendNull(); + VARCHAR.writeString(blockBuilder, "bob"); + Block block = OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + + assertTrue(block.isNull(0)); + assertEquals(VARCHAR.getObjectValue(SESSION, block, 1), "alice"); + assertTrue(block.isNull(2)); + assertEquals(VARCHAR.getObjectValue(SESSION, block, 3), "bob"); + + // build block from vec + Block block1 = new VariableWidthOmniBlock(4, (VarcharVec) block.getValues()); + assertTrue(block1.isNull(0)); + assertEquals(VARCHAR.getObjectValue(SESSION, block1, 1), "alice"); + assertTrue(block1.isNull(2)); + assertEquals(VARCHAR.getObjectValue(SESSION, block1, 3), "bob"); + block.close(); + } + + @Test + public void testInvalidInput() + { + Block baseBlock = buildBlockByBuilder(); + int[] offsets = {}; + byte[] bytes = {}; + assertThatThrownBy(() -> new VariableWidthOmniBlock(-1, baseBlock.getPositionCount(), + (VarcharVec) baseBlock.getValues(), offsets, bytes)).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("arrayOffset is negative"); + + int[] offsets2 = {}; + byte[] bytes2 = {}; + assertThatThrownBy( + () -> new VariableWidthOmniBlock(1, -1, (VarcharVec) baseBlock.getValues(), offsets2, bytes2)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("positionCount is negative"); + + int[] offsets2values = {}; + byte[] bytes2values = {}; + assertThatThrownBy(() -> new VariableWidthOmniBlock(1, 2, null, offsets2values, bytes2values)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("values is null"); + + int[] offsets2offsetsLen = {0}; + byte[] bytes2offsetsLen = {}; + assertThatThrownBy( + () -> new VariableWidthOmniBlock(1, 
2, (VarcharVec) baseBlock.getValues(), + offsets2offsetsLen, bytes2offsetsLen)).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("offsets length is less than positionCount"); + + byte[] bytes4 = {0}; + assertThatThrownBy(() -> new VariableWidthOmniBlock(1, 4, (VarcharVec) baseBlock.getValues(), null, bytes4)) + .isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("offsets is null"); + + int[] offsets5 = new int[6]; + byte[] bytes5 = {0}; + assertThatThrownBy( + () -> new VariableWidthOmniBlock(1, 4, (VarcharVec) baseBlock.getValues(), offsets5, bytes5)) + .isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("valueIsNull length is less than positionCount"); + + StringBuilder builder = new StringBuilder(); + Slice slice = Slices.wrappedBuffer(builder.toString().getBytes()); + byte[] bytes2Array = {0}; + assertThatThrownBy(() -> new VariableWidthOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, -1, -1, slice, + offsets, bytes2Array)).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("arrayOffset is negative"); + + assertThatThrownBy(() -> new VariableWidthOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 0, -1, slice, offsets, + bytes2Array)).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("positionCount is negative"); + + assertThatThrownBy(() -> new VariableWidthOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 0, 0, null, offsets, + bytes2Array)).isInstanceOfAny(IllegalArgumentException.class).hasMessageMatching("slice is null"); + + int[] offsets2valueIsNullLen = new int[6]; + assertThatThrownBy(() -> new VariableWidthOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, 1, 4, slice, + offsets2valueIsNullLen, bytes2Array)).isInstanceOfAny(IllegalArgumentException.class) + .hasMessageMatching("valueIsNull length is less than positionCount"); + baseBlock.close(); + } + + @Test + public void testGet() + { + Block baseBlock = buildBlockByBuilder(); + Block variableWidthOmniBlock = new 
VariableWidthOmniBlock(4, (VarcharVec) baseBlock.getValues()); + long expect = 10; + long expectSizeBytes = 39; + long expectStates = 5; + boolean[] position = {true, true, true, true}; + assertEquals(variableWidthOmniBlock.getRegionSizeInBytes(0, 1), expect); + assertEquals(variableWidthOmniBlock.getRegionSizeInBytes(0, 4), expectSizeBytes); + assertEquals(variableWidthOmniBlock.getEstimatedDataSizeForStats(0), expectStates); + assertEquals(variableWidthOmniBlock.getPositionsSizeInBytes(position), expectSizeBytes); + + variableWidthOmniBlock.close(); + } + + private Block buildBlockByBuilder() + { + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 4); + VARCHAR.writeString(blockBuilder, "alice"); + VARCHAR.writeString(blockBuilder, "bob"); + VARCHAR.writeString(blockBuilder, "charlie"); + VARCHAR.writeString(blockBuilder, "dave"); + return OperatorUtils.buildOffHeapBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, blockBuilder.build()); + } + + private VariableWidthOmniBlock getBlock(int count) + { + // returns test data + int[] offsets = new int[count + 1]; + int offset = 0; + StringBuilder buffer = new StringBuilder(); + + for (int i = 0; i < count; i++) { + offsets[i + 1] = offset; + String value = "value" + i; + buffer.append(value); + offset += value.getBytes().length; + } + Slice slice = Slices.wrappedBuffer(buffer.toString().getBytes()); + return new VariableWidthOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, count, slice, offsets, + Optional.empty()); + } + + private BloomFilter getBf(int size) + { + Random rnd = new Random(); + + BloomFilter bf = new BloomFilter(size, 0.01); + for (int i = 0; i < 100; i++) { + bf.test(("value" + rnd.nextLong()).getBytes()); + } + return bf; + } + + private static void assertBlockEquals(Block actual, VarcharVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(new String((byte[]) actual.get(position)), new String(expected.get(position))); + } + } + + private 
static void assertBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(type.getObjectValue(SESSION, actual, position), + type.getObjectValue(SESSION, expected, position)); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/e2e/TestExtensionExecutionPlan.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/e2e/TestExtensionExecutionPlan.java new file mode 100644 index 0000000000000000000000000000000000000000..3871d0a895ee44ee6f0fbf3a3db788f31252a164 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/e2e/TestExtensionExecutionPlan.java @@ -0,0 +1,5 @@ +package nova.hetu.olk.e2e; + +public class TestExtensionExecutionPlan +{ +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/LocalStateStoreProviderTest.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/LocalStateStoreProviderTest.java new file mode 100644 index 0000000000000000000000000000000000000000..2fe01ca40e7ea38da2655a3789dd0cb10377e23a --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/LocalStateStoreProviderTest.java @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.prestosql.metastore.MetaStoreConstants; +import io.prestosql.seedstore.SeedStoreManager; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.classloader.ThreadContextClassLoader; +import io.prestosql.spi.seedstore.SeedStoreSubType; +import io.prestosql.spi.statestore.StateCollection; +import io.prestosql.spi.statestore.StateStore; +import io.prestosql.spi.statestore.StateStoreFactory; +import io.prestosql.statestore.LocalStateStoreProvider; +import io.prestosql.statestore.StateStoreConstants; +import io.prestosql.statestore.StateStoreProvider; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom; +import static io.prestosql.spi.StandardErrorCode.STATE_STORE_FAILURE; +import static io.prestosql.statestore.StateStoreConstants.STATE_STORE_CONFIGURATION_PATH; +import static io.prestosql.statestore.StateStoreConstants.STATE_STORE_NAME_PROPERTY_NAME; +import static io.prestosql.statestore.StateStoreConstants.STATE_STORE_TYPE_PROPERTY_NAME; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class LocalStateStoreProviderTest + implements StateStoreProvider +{ + private static final Logger log = Logger.get(LocalStateStoreProvider.class); + private static final File STATE_STORE_CONFIGURATION = new File(STATE_STORE_CONFIGURATION_PATH); + private static final String DEFAULT_STATE_STORE_NAME = "default-state-store"; + private static final long SLEEP_INTERVAL = 2000L; + + private final Map stateStoreFactories = new ConcurrentHashMap<>(); + private StateStore stateStore; + private final SeedStoreManager seedStoreManager; + + @Inject + public 
LocalStateStoreProviderTest(SeedStoreManager seedStoreManager) + { + this.seedStoreManager = requireNonNull(seedStoreManager, "seedStoreManager is null"); + } + + @Override + public void addStateStoreFactory(StateStoreFactory factory) + { + if (stateStoreFactories.putIfAbsent(factory.getName(), factory) != null) { + throw new IllegalArgumentException(format("State Store '%s' is already registered", factory.getName())); + } + } + + @Override + public void loadStateStore() + throws Exception + { + if (STATE_STORE_CONFIGURATION.exists()) { + Map properties = new HashMap<>(loadPropertiesFrom(STATE_STORE_CONFIGURATION.getPath())); + String stateStoreType = properties.remove(STATE_STORE_TYPE_PROPERTY_NAME); + setStateStore(stateStoreType, properties); + createStateCollections(); + } + else { + log.info("No configuration file found, skip loading state store client"); + } + } + + public void setStateStore(String stateStoreType, Map properties) + { + requireNonNull(stateStoreType, "stateStoreType is null"); + requireNonNull(properties, "properties is null"); + + log.info("-- Loading state store --"); + StateStoreFactory stateStoreFactory = stateStoreFactories.get(stateStoreType); + checkState(stateStoreFactory != null, "State store %s is not registered", stateStoreType); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(stateStoreFactory.getClass().getClassLoader())) { + String stateStoreName = properties.remove(STATE_STORE_NAME_PROPERTY_NAME); + if (stateStoreName == null) { + log.info("State store name not provided, using default state store name: %s", DEFAULT_STATE_STORE_NAME); + stateStoreName = DEFAULT_STATE_STORE_NAME; + } + // Create state stores defined in config + stateStore = stateStoreFactory.create(stateStoreName, seedStoreManager.getSeedStore(SeedStoreSubType.HAZELCAST), ImmutableMap.copyOf(properties)); + stateStore.registerClusterFailureHandler(this::handleClusterDisconnection); + stateStore.init(); + } + catch (Exception e) { + throw new 
PrestoException(STATE_STORE_FAILURE, "Unable to create state store: " + e.getMessage()); + } + log.info("-- Loaded state store %s --", stateStoreType); + } + + @Override + public StateStore getStateStore() + { + return stateStore; + } + + public void createStateCollections() + { + // Create essential state collections + stateStore.createStateCollection(StateStoreConstants.DISCOVERY_SERVICE_COLLECTION_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(StateStoreConstants.QUERY_STATE_COLLECTION_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(StateStoreConstants.FINISHED_QUERY_STATE_COLLECTION_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(StateStoreConstants.OOM_QUERY_STATE_COLLECTION_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(StateStoreConstants.CPU_USAGE_STATE_COLLECTION_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(StateStoreConstants.TRANSACTION_STATE_COLLECTION_NAME, StateCollection.Type.MAP); + + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_CATALOGCACHE_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_CATALOGSCACHE_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_TABLECACHE_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_TABLESCACHE_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_DATABASECACHE_NAME, StateCollection.Type.MAP); + stateStore.createStateCollection(MetaStoreConstants.HETU_META_STORE_DATABASESCACHE_NAME, StateCollection.Type.MAP); + } + + void handleClusterDisconnection(Object obj) + { + log.info("Connection to Hazelcast state store has SHUTDOWN."); + while (true) { + try { + Thread.sleep(SLEEP_INTERVAL); + seedStoreManager.loadSeedStore(); + loadStateStore(); + break; + } + catch (Exception ex) { + 
log.info("Failed to reload state store: %s", ex.getMessage()); + } + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestAggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestAggregationOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..92859119d8be190433e4beeca40e6c6278a853d5 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestAggregationOmniOperator.java @@ -0,0 +1,305 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.RowBlockBuilder; +import io.prestosql.spi.plan.AggregationNode; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.olk.operator.AggregationOmniOperator.AggregationOmniOperatorFactory; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.slice.Slices.utf8Slice; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static 
io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.lang.String.format; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_MAX; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestAggregationOmniOperator +{ + private ExecutorService executor; + + private ScheduledExecutorService scheduledExecutor; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @DataProvider(name = "hashEnabled") + public static Object[][] hashEnabled() + { + return new Object[][]{{true}, {false}}; + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test(invocationCount = 1) + public void testAggregation() + { + List input = rowPagesBuilder(BIGINT, BIGINT, BIGINT, VARCHAR).addSequencePage(100, 0, 0, 0, 300).build(); + + int id = 0; + List inputTypes = ImmutableList.of(BIGINT, BIGINT, BIGINT, VARCHAR); + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, new 
VarcharDataType(10)}; + FunctionType[] aggregatorTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_AVG, + OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_MAX}; + int[] aggregationInputChannels = {0, 1, 2, 3}; + DataType[] aggregationOutputTypes = {LongDataType.LONG, DoubleDataType.DOUBLE, LongDataType.LONG, new VarcharDataType(10)}; + AggregationNode.Step step = AggregationNode.Step.SINGLE; + ImmutableList.Builder> maskChannels = new ImmutableList.Builder<>(); + for (int i = 0; i < aggregatorTypes.length; i++) { + maskChannels.add(Optional.empty()); + } + AggregationOmniOperatorFactory aggregationOmniOperatorFactory = new AggregationOmniOperatorFactory(id, + new PlanNodeId(String.valueOf(id)), inputTypes, aggregatorTypes, aggregationInputChannels, + maskChannels.build(), aggregationOutputTypes, step); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE, BIGINT, VARCHAR) + .row(100L, 49.5, 4950L, "399").build(); + assertOperatorEquals(aggregationOmniOperatorFactory, driverContext, input, expected); + assertEquals(driverContext.getSystemMemoryUsage(), 0); + assertEquals(driverContext.getMemoryUsage(), 0); + } + + @Test(invocationCount = 1) + public void testCountAggregationCompare() + { + List types = ImmutableList.of(BIGINT, BIGINT, INTEGER); + List input = rowPagesBuilder(types).row(1, 1, null).row(null, 2, 2).row(null, 3, 3).row(4, 4, 4) + .row(5, null, 5).row(null, 6, 6).build(); + + int id = 0; + List inputTypes = ImmutableList.of(BIGINT, INTEGER); + DataType[] sourceTypes = {LongDataType.LONG, IntDataType.INTEGER}; + FunctionType[] aggregatorTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL, + OMNI_AGGREGATION_TYPE_COUNT_COLUMN}; + int[] aggregationInputChannels = {0, 2}; + DataType[] aggregationOutputTypes = 
{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + AggregationNode.Step step = AggregationNode.Step.SINGLE; + ImmutableList.Builder> maskChannels = new ImmutableList.Builder<>(); + for (int i = 0; i < aggregatorTypes.length; i++) { + maskChannels.add(Optional.empty()); + } + AggregationOmniOperatorFactory aggregationOmniOperatorFactory = new AggregationOmniOperatorFactory(id, + new PlanNodeId(String.valueOf(id)), inputTypes, aggregatorTypes, aggregationInputChannels, + maskChannels.build(), aggregationOutputTypes, step); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT).row(3L, 6L, 5L) + .build(); + assertOperatorEquals(aggregationOmniOperatorFactory, driverContext, input, expected); + assertEquals(driverContext.getSystemMemoryUsage(), 0); + assertEquals(driverContext.getMemoryUsage(), 0); + } + + @Test(invocationCount = 1) + public void testCountAggregation() + { + List input = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN, BOOLEAN).row(10L, 20L, true, true) + .row(20L, 10L, true, true).pageBreak().row(10L, 30L, false, true).row(30L, 10L, true, false).build(); + + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + int id = 0; + List inputTypes = ImmutableList.of(BIGINT, BIGINT, BOOLEAN, BOOLEAN); + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, BooleanDataType.BOOLEAN, BooleanDataType.BOOLEAN}; + FunctionType[] aggregatorTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_SUM}; + int[] aggregationInputChannels = {0, 1}; + DataType[] aggregationOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + AggregationNode.Step step = AggregationNode.Step.SINGLE; + ImmutableList.Builder> maskChannels = new ImmutableList.Builder<>(); + 
maskChannels.add(Optional.of(2)); + maskChannels.add(Optional.of(3)); + + AggregationOmniOperatorFactory aggregationOmniOperatorFactory = new AggregationOmniOperatorFactory(id, + new PlanNodeId(String.valueOf(id)), inputTypes, aggregatorTypes, aggregationInputChannels, + maskChannels.build(), aggregationOutputTypes, step); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT).row(3L, 60L).build(); + assertOperatorEquals(aggregationOmniOperatorFactory, driverContext, offHeapInput, expected); + assertEquals(driverContext.getSystemMemoryUsage(), 0); + assertEquals(driverContext.getMemoryUsage(), 0); + } + + @Test(invocationCount = 1) + public void testAggregationWithRowBlock() + { + List fieldTypes = ImmutableList.of(VARCHAR, BIGINT); + List[] testRows = generateTestRows(fieldTypes, 100); + + BlockBuilder blockBuilder = createBlockBuilderWithValues(fieldTypes, testRows); + Block block = blockBuilder.build(); + Block[] blocks = new Block[1]; + blocks[0] = block; + + Page page1 = new Page(blocks); + Page page2 = new Page(blocks); + List input = new ArrayList<>(); + input.add(page1); + input.add(page2); + + List fields = new ArrayList<>(); + for (int i = 0; i < fieldTypes.size(); i++) { + fields.add(new RowType.Field(Optional.of(i + ""), fieldTypes.get(i))); + } + RowType c = RowType.from(fields); + // transfer on-heap page to off-heap + List offHeapPages = input.stream().map(var -> + OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, var, + Arrays.stream(var.getBlocks()).map(val -> c).collect(Collectors.toList()))) + .collect(Collectors.toList()); + + AggregationNode.Step step = AggregationNode.Step.PARTIAL; + int id = 0; + List inputTypes = ImmutableList.of(BIGINT, BIGINT, BIGINT, VARCHAR); + DataType[] sourceTypes = {LongDataType.LONG, 
LongDataType.LONG, LongDataType.LONG, new VarcharDataType(10)}; + FunctionType[] aggregatorTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN}; + int[] aggregationInputChannels = {0, 1, 2, 3}; + DataType[] aggregationOutputTypes = {LongDataType.LONG}; + ImmutableList.Builder> maskChannels = new ImmutableList.Builder<>(); + for (int i = 0; i < aggregatorTypes.length; i++) { + maskChannels.add(Optional.empty()); + } + AggregationOmniOperatorFactory aggregationOmniOperatorFactory = new AggregationOmniOperatorFactory(id, + new PlanNodeId(String.valueOf(id)), inputTypes, aggregatorTypes, aggregationInputChannels, + maskChannels.build(), aggregationOutputTypes, step); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT).row(200L).build(); + assertOperatorEquals(aggregationOmniOperatorFactory, driverContext, offHeapPages, expected); + assertEquals(driverContext.getSystemMemoryUsage(), 0); + assertEquals(driverContext.getMemoryUsage(), 0); + } + + private List[] generateTestRows(List fieldTypes, int numRows) + { + List[] testRows = new List[numRows]; + for (int i = 0; i < numRows; i++) { + List testRow = new ArrayList<>(fieldTypes.size()); + for (int j = 0; j < fieldTypes.size(); j++) { + int cellId = i * fieldTypes.size() + j; + if (cellId % 7 == 3) { + // put null value for every 7 cells + testRow.add(null); + } + else { + if (fieldTypes.get(j) == BIGINT) { + testRow.add(i * 100L + j); + } + else if (fieldTypes.get(j) == VARCHAR) { + testRow.add(format("field(%s, %s)", i, j)); + } + else { + throw new IllegalArgumentException(); + } + } + } + testRows[i] = testRow; + } + return testRows; + } + + private BlockBuilder createBlockBuilderWithValues(List fieldTypes, List[] rows) + { + BlockBuilder rowBlockBuilder = new RowBlockBuilder(fieldTypes, null, 1); + for (List row : rows) { + if (row == 
null) { + rowBlockBuilder.appendNull(); + } + else { + BlockBuilder singleRowBlockWriter = rowBlockBuilder.beginBlockEntry(); + for (Object fieldValue : row) { + if (fieldValue == null) { + singleRowBlockWriter.appendNull(); + } + else { + if (fieldValue instanceof Long) { + BIGINT.writeLong(singleRowBlockWriter, ((Long) fieldValue).longValue()); + } + else if (fieldValue instanceof String) { + VARCHAR.writeSlice(singleRowBlockWriter, utf8Slice((String) fieldValue)); + } + else { + throw new IllegalArgumentException(); + } + } + } + rowBlockBuilder.closeEntry(); + } + } + return rowBlockBuilder; + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOffHeapOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOffHeapOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..693646e966457b2124dab07b61820ee01ebd5647 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOffHeapOmniOperator.java @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.spi.Page; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.olk.tool.TestOperatorUtils.buildPages; +import static nova.hetu.olk.tool.TestOperatorUtils.freeNativeMemory; + +@Test(singleThreaded = true) +public class TestBuildOffHeapOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private DriverContext driverContext; + + List types = new ImmutableList.Builder().add(INTEGER).add(BIGINT).add(DOUBLE).add(BOOLEAN).add(VARCHAR) + 
.add(DATE).build(); + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test(enabled = true) + public void testFunctionWithoutDictionary() + { + List pages = buildPages(types, false, 100); + + OperatorFactory operatorFactory = new BuildOffHeapOmniOperator.BuildOffHeapOmniOperatorFactory(0, + new PlanNodeId("test"), types); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + List expectedPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), types).pages(expectedPages).build(); + assertOperatorEquals(operatorFactory, driverContext, pages, expected); + + freeNativeMemory(expectedPages); + } + + @Test(enabled = true) + public void testFunctionWithDictionary() + { + List pages = buildPages(types, true, 100); + + OperatorFactory operatorFactory = new BuildOffHeapOmniOperator.BuildOffHeapOmniOperatorFactory(0, + new PlanNodeId("test"), types); + + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + List expectedPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), types).pages(expectedPages).build(); + assertOperatorEquals(operatorFactory, driverContext, pages, expected); + + freeNativeMemory(expectedPages); + } +} diff --git 
a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOnHeapOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOnHeapOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..2dbd0e1e03caa68ff0e4cbe2c14161ed6f2796c0 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestBuildOnHeapOmniOperator.java @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.spi.Page; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.olk.tool.TestOperatorUtils.buildPages; + +@Test(singleThreaded = true) +public class TestBuildOnHeapOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + + List types = new 
ImmutableList.Builder().add(INTEGER).add(BIGINT).add(REAL).add(DOUBLE).add(VARCHAR) + .add(DATE).add(TIMESTAMP).add(BOOLEAN).add(createDecimalType(20, 10)).build(); + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test(enabled = true) + public void testFunctionWithoutDictionary() + { + List pages = buildPages(types, false, 100); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + OperatorFactory operatorFactory = new BuildOnHeapOmniOperator.BuildOnHeapOmniOperatorFactory(0, + new PlanNodeId("test")); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), types).pages(pages).build(); + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = true) + public void testFunctionWithDictionary() + { + List pages = buildPages(types, true, 100); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + OperatorFactory operatorFactory = new BuildOnHeapOmniOperator.BuildOnHeapOmniOperatorFactory(0, + new PlanNodeId("test")); + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), types).pages(pages).build(); + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } +} diff --git 
a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestCompareFunction.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestCompareFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..ff3e7c21671729a47db288d8280499651651e2fb --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestCompareFunction.java @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.VariableWidthBlock; +import nova.hetu.olk.BlockUtil; +import org.testng.annotations.Test; + +public class TestCompareFunction +{ + int rowSize = 1000000; + int width = 10; + int round = 10; + + VariableWidthBlock variableWidthBlock1 = (VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 10); + VariableWidthBlock variableWidthBlock2 = (VariableWidthBlock) variableWidthBlock1.copyRegion(0, + variableWidthBlock1.getPositionCount()); + VariableWidthBlock variableWidthBlock3 = (VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 20); + Slice[] slice1 = BlockUtil.getBlockSlices(variableWidthBlock1, rowSize, width); + Slice[] slice2 = BlockUtil.getBlockSlices(variableWidthBlock2, rowSize, width); + Slice[] slice3 = BlockUtil.getBlockSlices(variableWidthBlock3, rowSize, width); + + @Test + public void testSameSliceComparePerf() + { + int comp = 0; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + comp = slice1[i].compareTo(slice2[i]); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("comp: " + comp); + System.out.println("avg time: " + sum / round + " ms"); + } + + @Test + public void testDiffSliceComparePerf() + { + int comp = 0; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + comp = slice1[i].compareTo(slice3[i]); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("comp: " + comp); + System.out.println("avg time: " + sum / round + " ms"); + } + + @Test + 
public void testSameBlockComparePerf() + { + int comp = 0; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + comp = variableWidthBlock1.compareTo(i, 0, width, variableWidthBlock2, i, 0, width); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("comp: " + comp); + System.out.println("avg time: " + sum / round + " ms"); + } + + @Test + public void testDiffBlockComparePerf() + { + int comp = 0; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + comp = variableWidthBlock1.compareTo(i, 0, width, variableWidthBlock3, i, 0, width); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("comp: " + comp); + System.out.println("avg time: " + sum / round + " ms"); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDistinctLimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDistinctLimitOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..815fc9ccffd394656af5b59aab8dacc01426737e --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDistinctLimitOmniOperator.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.spi.Page; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.OperatorAssertion.toMaterializedResult; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.CharType.createCharType; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; 
+import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; + +@Test(singleThreaded = true) +public class TestDistinctLimitOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private DriverContext driverContext; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test(enabled = true) + public void testDistinctLimitColTypesCover() + { + List input = rowPagesBuilder(BIGINT, INTEGER, createCharType(10), INTEGER, DOUBLE, BOOLEAN, VARCHAR, + createDecimalType(10, 2)).row(1000L, 3, "aaa", 0, 6.6, true, "hello", 1001) + .row(2000L, 4, "bbb", 1, 5.5, false, "world", 2002).pageBreak() + .row(1000L, 5, "aaa", 0, 6.6, true, "hello", 1001) + .row(3000L, 6, "ccc", 2, 4.4, false, "welcome", 3003).build(); + + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + OperatorFactory operatorFactory = new DistinctLimitOmniOperator.DistinctLimitOmniOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(BIGINT, INTEGER, createCharType(10), INTEGER, DOUBLE, + BOOLEAN, VARCHAR, createDecimalType(10, 2)), ImmutableList.of(2, 3, 4, 5, 6, 7), Optional.of(0), 100); + + List expectedTypes = ImmutableList.of(createCharType(10), INTEGER, DOUBLE, BOOLEAN, VARCHAR, + createDecimalType(10, 2), BIGINT); + + List output = rowPagesBuilder(createCharType(10), INTEGER, DOUBLE, BOOLEAN, 
VARCHAR, + createDecimalType(10, 2), BIGINT).row("aaa", 0, 6.6, true, "hello", 1001, 1000L) + .row("bbb", 1, 5.5, false, "world", 2002, 2000L) + .row("ccc", 2, 4.4, false, "welcome", 3003, 3000L).build(); + + MaterializedResult expected = toMaterializedResult(driverContext.getSession(), expectedTypes, output); + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = true) + public void testDistinctLimitBasic() + { + List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(1L, 0.1) + .row(-1L, -0.1).row(4L, 0.4).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build(); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new DistinctLimitOmniOperator.DistinctLimitOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE), ImmutableList.of(0, 1), Optional.empty(), 3); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(1L, 0.1) + .row(2L, 0.2).row(-1L, -0.1).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = true) + public void testDistinctLimitHasNull() + { + List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(null, null) + .row(1L, 0.1).row(4L, 0.4).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build(); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new DistinctLimitOmniOperator.DistinctLimitOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE), ImmutableList.of(0, 1), Optional.empty(), 4); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(1L, 0.1) + .row(2L, 0.2).row(null, null).row(4L, 
0.4).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = true) + public void testDistinctLimitHashCol() + { + List input = rowPagesBuilder(BIGINT, DOUBLE, BIGINT).row(1L, 0.1, 10L).row(2L, 0.2, 20L).pageBreak() + .row(1L, 0.1, 10L).row(-1L, -0.1, -10L).row(4L, 0.4, 40L).pageBreak().row(5L, 0.5, 50L) + .row(4L, 0.41, 60L).row(6L, 0.6, 70L).pageBreak().build(); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new DistinctLimitOmniOperator.DistinctLimitOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE, BIGINT), ImmutableList.of(0, 1), Optional.of(2), 3); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE, BIGINT) + .row(1L, 0.1, 10L).row(2L, 0.2, 20L).row(-1L, -0.1, -10L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDynamicFilterSourceOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDynamicFilterSourceOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..94f18922449771410d66d95237c459809adf458d --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestDynamicFilterSourceOmniOperator.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.collect.MultimapBuilder; +import io.airlift.node.NodeInfo; +import io.hetu.core.statestore.hazelcast.HazelcastStateStoreBootstrapper; +import io.hetu.core.statestore.hazelcast.HazelcastStateStoreFactory; +import io.prestosql.execution.TaskId; +import io.prestosql.operator.DynamicFilterSourceOperator; +import io.prestosql.operator.DynamicFilterSourceOperator.Channel; +import io.prestosql.operator.Operator; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.operator.PipelineContext; +import io.prestosql.seedstore.SeedStoreManager; +import io.prestosql.spi.Page; +import io.prestosql.spi.dynamicfilter.DynamicFilter; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.plan.Symbol; +import io.prestosql.spi.seedstore.Seed; +import io.prestosql.spi.seedstore.SeedStore; +import io.prestosql.spi.seedstore.SeedStoreSubType; +import io.prestosql.spi.statestore.StateStore; +import io.prestosql.spi.statestore.StateStoreBootstrapper; +import io.prestosql.spi.statestore.StateStoreFactory; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.analyzer.FeaturesConfig; +import io.prestosql.sql.planner.LocalDynamicFilter; +import io.prestosql.statestore.StateStoreProvider; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import 
org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.hetu.core.statestore.hazelcast.HazelcastConstants.DISCOVERY_PORT_CONFIG_NAME; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.SystemSessionProperties.getDynamicFilteringMaxPerDriverSize; +import static io.prestosql.SystemSessionProperties.getDynamicFilteringMaxPerDriverValueCount; +import static io.prestosql.block.BlockAssertions.createLongsBlock; +import static io.prestosql.operator.OperatorAssertion.toMaterializedResult; +import static io.prestosql.operator.OperatorAssertion.toPages; +import static io.prestosql.spi.dynamicfilter.DynamicFilter.Type.LOCAL; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.analyzer.FeaturesConfig.DynamicFilterDataType.HASHSET; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static io.prestosql.testing.assertions.Assert.assertEquals; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.stream.Collectors.toList; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@Test(singleThreaded = true) +public class TestDynamicFilterSourceOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private PipelineContext pipelineContext; + private StateStoreProvider stateStoreProvider; + + 
@BeforeTest + private void prepareConfigFiles() + throws Exception + { + Set seeds = new HashSet<>(); + SeedStore mockSeedStore = mock(SeedStore.class); + Seed mockSeed = mock(Seed.class); + seeds.add(mockSeed); + + SeedStoreManager mockSeedStoreManager = mock(SeedStoreManager.class); + when(mockSeedStoreManager.getSeedStore(SeedStoreSubType.HAZELCAST)).thenReturn(mockSeedStore); + + when(mockSeed.getLocation()).thenReturn("127.0.0.1:6991"); + when(mockSeedStore.get()).thenReturn(seeds); + + StateStoreFactory factory = new HazelcastStateStoreFactory(); + stateStoreProvider = new LocalStateStoreProviderTest(mockSeedStoreManager); + stateStoreProvider.addStateStoreFactory(factory); + createStateStoreCluster("6991"); + stateStoreProvider.loadStateStore(); + } + + @AfterTest + private void cleanUp() + { + } + + @BeforeMethod + public void setUp() throws Exception + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + pipelineContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false); + } + + private StateStore createStateStoreCluster(String port) + { + Map config = new HashMap<>(); + config.put("hazelcast.discovery.mode", "tcp-ip"); + config.put("state-store.cluster", "test-cluster"); + config.put(DISCOVERY_PORT_CONFIG_NAME, port); + + StateStoreBootstrapper bootstrapper = new HazelcastStateStoreBootstrapper(); + return bootstrapper.bootstrap(ImmutableSet.of("127.0.0.1:" + port), config); + } + + @AfterMethod + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + private void verifyPassthrough(Operator operator, List types, Page... 
pages) + { + List inputPages = ImmutableList.copyOf(pages); + List outputPages = toPages(operator, inputPages.iterator()); + MaterializedResult actual = toMaterializedResult(pipelineContext.getSession(), types, outputPages); + MaterializedResult expected = toMaterializedResult(pipelineContext.getSession(), types, inputPages); + assertEquals(actual, expected); + } + + private DynamicFilterSourceOmniOperator.DynamicFilterSourceOmniOperatorFactory createOperatorFactory( + DynamicFilter.Type dfType, FeaturesConfig.DynamicFilterDataType dataType, int partitionCount, + List types, Channel... buildChannels) + { + NodeInfo nodeInfo = new NodeInfo("test"); + Multimap probeSymbols = MultimapBuilder.treeKeys().arrayListValues().build(); + Map buildChannelMap = new HashMap<>(); + Arrays.stream(buildChannels).map(channel -> buildChannelMap.put(channel.getFilterId(), channel.getIndex())); + Arrays.stream(buildChannels).map(channel -> probeSymbols.put(channel.getFilterId(), new Symbol(String.valueOf(channel.getIndex())))); + + TaskId taskId = new TaskId("test0.0"); + LocalDynamicFilter localDynamicFilter = new LocalDynamicFilter(probeSymbols, + buildChannelMap, partitionCount, dfType, dataType, 0.1D, taskId, stateStoreProvider); + + return new DynamicFilterSourceOmniOperator.DynamicFilterSourceOmniOperatorFactory(0, + new PlanNodeId("PLAN_NODE_ID"), localDynamicFilter.getValueConsumer(), + Arrays.stream(buildChannels).collect(toList()), getDynamicFilteringMaxPerDriverValueCount(TEST_SESSION), + getDynamicFilteringMaxPerDriverSize(TEST_SESSION), types); + } + + private DynamicFilterSourceOperator createOperator( + DynamicFilterSourceOmniOperator.DynamicFilterSourceOmniOperatorFactory operatorFactory) + { + return operatorFactory.createOperator(pipelineContext.addDriverContext()); + } + + @Test + private void testCollectNoFilters() + { + OperatorFactory operatorFactory = createOperatorFactory(LOCAL, HASHSET, 1, ImmutableList.of()); + verifyPassthrough( + createOperator( + 
(DynamicFilterSourceOmniOperator.DynamicFilterSourceOmniOperatorFactory) operatorFactory), + ImmutableList.of(BIGINT), + transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, new Page(createLongsBlock(1, 2, 3)))); + operatorFactory.noMoreOperators(); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEnforceSingleRowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEnforceSingleRowOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..03b68f5fb35013c030e1bcf18d4f54c7d6a3e2b1 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEnforceSingleRowOmniOperator.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.operator.TaskContext; +import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.plan.PlanNodeId; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.olk.tool.TestOperatorUtils.assertPagesEquals; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestEnforceSingleRowOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private TaskContext taskContext; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + taskContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + 
executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test + public void testEnforceSingleRowOmniOperator() + { + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + OperatorFactory factory = new EnforceSingleRowOmniOperator.EnforceSingleRowOmniOperatorFactory(0, new PlanNodeId("plan-node-0"), ImmutableList.of(BIGINT)); + assertEquals(factory.getSourceTypes(), Arrays.asList(BIGINT)); + assertEquals(factory.isExtensionOperatorFactory(), true); + factory = factory.duplicate(); + EnforceSingleRowOmniOperator operator = (EnforceSingleRowOmniOperator) factory.createOperator(driverContext); + factory.noMoreOperators(); + + List input = rowPagesBuilder(BIGINT).addSequencePage(1, 0).addSequencePage(2, 0).build(); + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + operator.addInput(offHeapInput.get(0)); + + boolean caught = false; + try { + operator.addInput(offHeapInput.get(1)); + } + catch (PrestoException e) { + caught = true; + } + assertTrue(caught, "Operator didn't catch input of position count 2."); + + assertEquals(operator.getOutput(), null); + operator.finish(); + assertPagesEquals(Arrays.asList(BIGINT), Arrays.asList(operator.getOutput()), + Arrays.asList(offHeapInput.get(0))); + operator.close(); + } + + private Map createExpectedMapping() + { + Map expectedMapping = new HashMap<>(); + expectedMapping.put("operatorContext", 0); + expectedMapping.put("finishing", false); + return expectedMapping; + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEqualsFunction.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEqualsFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..333ced52598dbe2a0e8e78e4425e3ba0c26f063b --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestEqualsFunction.java @@ 
-0,0 +1,78 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import io.prestosql.spi.block.VariableWidthBlock; +import nova.hetu.olk.BlockUtil; +import org.testng.annotations.Test; + +public class TestEqualsFunction +{ + int rowSize = 1000000; + int width = 10; + int round = 10; + + @Test + public void testSameBlockEqualsPerf() + { + VariableWidthBlock variableWidthBlock1 = (VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 10); + VariableWidthBlock variableWidthBlock2 = (VariableWidthBlock) variableWidthBlock1.copyRegion(0, + variableWidthBlock1.getPositionCount()); + + boolean isEqual = false; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + isEqual = variableWidthBlock1.equals(i, 0, variableWidthBlock2, i, 0, width); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("isEqual: " + isEqual); + System.out.println("avg time: " + sum / round + " ms"); + } + + @Test + public void testDiffBlockEqualsPerf() + { + VariableWidthBlock variableWidthBlock1 = (VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 10); + VariableWidthBlock variableWidthBlock2 = 
(VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 20); + + boolean isEqual = true; + double sum = 0; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + isEqual = variableWidthBlock1.equals(i, 0, variableWidthBlock2, i, 0, width); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("isEqual: " + isEqual); + System.out.println("avg time: " + sum / round + " ms"); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashAggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashAggregationOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..0d3328daebf77d1ed8abbe6a5bcbce5463727237 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashAggregationOmniOperator.java @@ -0,0 +1,440 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; +import io.prestosql.metadata.Metadata; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.HashAggregationOperator; +import io.prestosql.operator.aggregation.InternalAggregationFunction; +import io.prestosql.spi.Page; +import io.prestosql.spi.PageBuilder; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.plan.AggregationNode; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.gen.JoinCompiler; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.TestingTaskContext; +import nova.hetu.olk.tool.BlockUtils; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; +import org.testng.internal.collections.Ints; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.airlift.units.DataSize.succinctBytes; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static 
io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.OperatorAssertion.assertPagesEqualIgnoreOrder; +import static io.prestosql.operator.OperatorAssertion.toPages; +import static io.prestosql.spi.function.FunctionKind.AGGREGATE; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_MAX; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_MIN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestHashAggregationOmniOperator +{ + private static final Metadata metadata = createTestMetadataManager(); + + private ExecutorService executor; + + private ScheduledExecutorService scheduledExecutor; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @DataProvider(name = "hashEnabled") + public static Object[][] 
hashEnabled() + { + return new Object[][]{{true}, {false}}; + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + private List builderPage() + { + List dataTypes = new ArrayList<>(); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + + List inputPages = new ArrayList<>(); + for (int k = 0; k < totalPageCount; k++) { + PageBuilder pb = PageBuilder.withMaxPageSize(Integer.MAX_VALUE, dataTypes); + BlockBuilder group1 = pb.getBlockBuilder(0); + BlockBuilder group2 = pb.getBlockBuilder(1); + BlockBuilder sum1 = pb.getBlockBuilder(2); + BlockBuilder sum2 = pb.getBlockBuilder(3); + + for (int i = 0; i < pageDistinctCount; i++) { + for (int j = 0; j < pageDistinctValueRepeatCount; j++) { + group1.writeLong(i); + group2.writeLong(i); + sum1.writeLong(1); + sum2.writeLong(1); + pb.declarePosition(); + } + } + Page build = pb.build(); + inputPages.add(build); + } + return inputPages; + } + + final int pageDistinctCount = 4; + + final int pageDistinctValueRepeatCount = 250; + + final int totalPageCount = 10; + + final int threadNum = 10; + + @Test(invocationCount = 1) + public void testCountAggregation() + { + int[] omniGroupByChannels = {0, 1}; + DataType[] omniGroupByTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] omniAggregationChannels = {2}; + DataType[] omniAggregationTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] omniAggregator = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL}; + int[] maskChannels = {-1, -1}; + List inAndOutputTypes = new ArrayList<>(); + inAndOutputTypes.add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + + // expected + DriverContext driverContext = createDriverContext(Integer.MAX_VALUE); + MaterializedResult expected = 
getExpectedMaterializedRows(driverContext); + + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + CopyOnWriteArrayList> resultList = new CopyOnWriteArrayList<>(); + + for (int i = 0; i < threadNum; i++) { + int id = i; + Thread thread = new Thread(() -> { + try { + List input = builderPage(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + input); + List pages; + + HashAggregationOmniOperator.HashAggregationOmniOperatorFactory hashAggregationOmniOperatorFactory = new HashAggregationOmniOperator.HashAggregationOmniOperatorFactory( + id, new PlanNodeId(String.valueOf(id)), omniGroupByChannels, omniGroupByTypes, + omniAggregationChannels, omniAggregationTypes, omniAggregator, inAndOutputTypes, + maskChannels); + pages = toPages(hashAggregationOmniOperatorFactory, driverContext, offHeapPages, false); + resultList.add(pages); + } + finally { + countDownLatch.countDown(); + } + }); + thread.start(); + } + try { + countDownLatch.await(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertEquals(resultList.size(), threadNum); + + for (List pages : resultList) { + assertPagesEqualIgnoreOrder(driverContext, pages, expected, false, Optional.empty()); + for (Page page : pages) { + BlockUtils.freePage(page); + } + } + resultList.clear(); + } + + @Test(invocationCount = 1) + public void testHashAggregation() + { + int[] omniGroupByChannels = {0, 1}; + DataType[] omniGroupByTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] omniAggregationChannels = {2, 3}; + DataType[] omniAggregationTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] omniAggregator = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + int[] maskChannels = {-1, -1}; + List inAndOutputTypes = new ArrayList<>(); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, 
LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + + // expected + DriverContext driverContext = createDriverContext(Integer.MAX_VALUE); + MaterializedResult expected = getExpectedMaterializedRows(driverContext); + + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + CopyOnWriteArrayList> resultList = new CopyOnWriteArrayList<>(); + + for (int i = 0; i < threadNum; i++) { + int id = i; + Thread thread = new Thread(() -> { + try { + List input = builderPage(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + input); + List pages; + + HashAggregationOmniOperator.HashAggregationOmniOperatorFactory hashAggregationOmniOperatorFactory = new HashAggregationOmniOperator.HashAggregationOmniOperatorFactory( + id, new PlanNodeId(String.valueOf(id)), omniGroupByChannels, omniGroupByTypes, + omniAggregationChannels, omniAggregationTypes, omniAggregator, inAndOutputTypes, + maskChannels); + pages = toPages(hashAggregationOmniOperatorFactory, driverContext, offHeapPages, false); + resultList.add(pages); + } + finally { + countDownLatch.countDown(); + } + }); + thread.start(); + } + try { + countDownLatch.await(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertEquals(resultList.size(), threadNum); + + for (List pages : resultList) { + assertPagesEqualIgnoreOrder(driverContext, pages, expected, false, Optional.empty()); + } + } + + private MaterializedResult getExpectedMaterializedRows(DriverContext driverContext) + { + MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT, + BIGINT); + long sum = totalPageCount * pageDistinctValueRepeatCount; + for (int i = 0; i < pageDistinctCount; i++) { + expectedBuilder.row((long) i, (long) i, sum, sum); + } + MaterializedResult expected = expectedBuilder.build(); + return expected; + } + + protected static final JoinCompiler JOIN_COMPILER = new JoinCompiler(createTestMetadataManager()); 
+ + private HashAggregationOperator.HashAggregationOperatorFactory getOriginalAggFactory(int id) + { + InternalAggregationFunction bigintSum = metadata.getFunctionAndTypeManager() + .getAggregateFunctionImplementation(new Signature(QualifiedObjectName.valueOfDefaultFunction("sum"), + AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); + HashAggregationOperator.HashAggregationOperatorFactory aggregationOperatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory( + id, new PlanNodeId(String.valueOf(id)), ImmutableList.of(BIGINT, BIGINT), Ints.asList(0, 1), + ImmutableList.of(), AggregationNode.Step.SINGLE, + ImmutableList.of(bigintSum.bind(ImmutableList.of(2), Optional.empty()), + bigintSum.bind(ImmutableList.of(3), Optional.empty())), + Optional.empty(), Optional.empty(), 100_000, Optional.of(new DataSize(16, MEGABYTE)), JOIN_COMPILER, + false); + return aggregationOperatorFactory; + } + + private DriverContext createDriverContext(long memoryLimit) + { + return TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setMemoryPoolSize(succinctBytes(memoryLimit)).build().addPipelineContext(0, true, true, false) + .addDriverContext(); + } + + @Test(invocationCount = 1) + public void testHashAggregationWithDiffLayout() + { + int[] omniGroupByChannels = {3, 0}; + DataType[] omniGroupByTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] omniAggregationChannels = {2, 1}; + DataType[] omniAggregationTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] omniAggregator = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + int[] maskChannels = {-1, -1}; + List inAndOutputTypes = new ArrayList<>(); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}); + + DriverContext driverContext = createDriverContext(Integer.MAX_VALUE); + 
MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT, + BIGINT); + long sum = totalPageCount * pageDistinctValueRepeatCount * 10; + for (int i = 0; i < pageDistinctCount; i++) { + expectedBuilder.row((long) i + 1, (long) i, sum, sum); + } + MaterializedResult expected = expectedBuilder.build(); + + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + CopyOnWriteArrayList> resultList = new CopyOnWriteArrayList<>(); + + for (int i = 0; i < threadNum; i++) { + int id = i; + Thread thread = new Thread(() -> { + try { + List input = builderPageWithDiffLayout(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + input); + List pages; + + HashAggregationOmniOperator.HashAggregationOmniOperatorFactory hashAggregationOmniOperatorFactory = new HashAggregationOmniOperator.HashAggregationOmniOperatorFactory( + id, new PlanNodeId(String.valueOf(id)), omniGroupByChannels, omniGroupByTypes, + omniAggregationChannels, omniAggregationTypes, omniAggregator, inAndOutputTypes, + maskChannels); + pages = toPages(hashAggregationOmniOperatorFactory, driverContext, offHeapPages, false); + resultList.add(pages); + } + finally { + countDownLatch.countDown(); + } + }); + thread.start(); + } + try { + countDownLatch.await(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + assertEquals(resultList.size(), threadNum); + + for (List pages : resultList) { + assertPagesEqualIgnoreOrder(driverContext, pages, expected, false, Optional.empty()); + } + } + + @Test(invocationCount = 1) + public void testHashAggregationWithMarkDistinct() + { + List input = rowPagesBuilder(BIGINT, BIGINT, BIGINT, BIGINT, BIGINT, BIGINT, BOOLEAN, BOOLEAN, BOOLEAN, + BOOLEAN, BOOLEAN).row(10L, 20L, 20L, 20L, 20L, 20L, true, true, true, true, true) + .row(10L, 10L, 10L, 10L, 10L, 10L, true, true, true, true, true).pageBreak() + .row(10L, 30L, 30L, 30L, 30L, 30L, false, false, false, 
false, false) + .row(10L, 10L, 10L, 10L, 10L, 10L, true, true, true, true, true).build(); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + int id = 0; + int[] groupByChannels = {0}; + DataType[] groupByTypes = {LongDataType.LONG}; + FunctionType[] aggregatorTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_SUM, + OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN}; + int[] aggregationChannels = {1, 2, 3, 4, 5}; + DataType[] aggregationTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG}; + List inAndOutputTypes = new ArrayList<>(); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG, LongDataType.LONG}); + inAndOutputTypes + .add(new DataType[]{LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + DoubleDataType.DOUBLE, LongDataType.LONG, LongDataType.LONG}); + int[] maskChannels = {6, 7, 8, 9, 10}; + HashAggregationOmniOperator.HashAggregationOmniOperatorFactory operatorFactory = new HashAggregationOmniOperator.HashAggregationOmniOperatorFactory( + id, new PlanNodeId(String.valueOf(id)), groupByChannels, groupByTypes, aggregationChannels, + aggregationTypes, aggregatorTypes, inAndOutputTypes, maskChannels); + + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BIGINT, DOUBLE, BIGINT, + BIGINT).row(10L, 3L, 40L, (40D / 3D), 20L, 10L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + assertEquals(driverContext.getSystemMemoryUsage(), 0); + assertEquals(driverContext.getMemoryUsage(), 0); + } + + private List builderPageWithDiffLayout() + { + List dataTypes = new 
ArrayList<>(); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + dataTypes.add(BIGINT); + + List inputPages = new ArrayList<>(); + for (int k = 0; k < totalPageCount; k++) { + PageBuilder pb = PageBuilder.withMaxPageSize(Integer.MAX_VALUE, dataTypes); + BlockBuilder group1 = pb.getBlockBuilder(1); + BlockBuilder group2 = pb.getBlockBuilder(0); + BlockBuilder sum1 = pb.getBlockBuilder(2); + BlockBuilder sum2 = pb.getBlockBuilder(3); + + for (int i = 0; i < pageDistinctCount; i++) { + for (int j = 0; j < pageDistinctValueRepeatCount; j++) { + group1.writeLong(i); + group2.writeLong(i + 1); + sum1.writeLong(10); + sum2.writeLong(10); + pb.declarePosition(); + } + } + Page build = pb.build(); + inputPages.add(build); + } + return inputPages; + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashFunction.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..c26b5d6f9bec9e571ab1e444895bb7e2067a8dc7 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashFunction.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.VariableWidthBlock; +import nova.hetu.olk.BlockUtil; +import org.testng.annotations.Test; + +import static io.airlift.slice.XxHash64.hash; + +public class TestHashFunction +{ + int rowSize = 1000000; + int width = 10; + int round = 10; + + @Test + public void testVarcharHashPerf() + { + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) BlockUtil.buildVarcharBlock(rowSize, width, 10); + Slice[] slice = BlockUtil.getBlockSlices(variableWidthBlock, rowSize, width); + + long hashVal = 0; + double sum = 0; + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (int i = 0; i < rowSize; i++) { + hashVal = hash(10, slice[i]); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("hashVal: " + hashVal); + System.out.println("avg time: " + sum / round + " ms"); + } + + @Test + public void testLongHashPerf() + { + long hashVal = 0; + double sum = 0; + long start = 10000000; + + for (int j = 0; j < round; j++) { + long startTime = System.nanoTime(); + + for (long i = start; i < rowSize + start; i++) { + hashVal = hash(i); + } + + long endTime = System.nanoTime(); + long duration = (endTime - startTime) / 1000_000; + System.out.println("Round: " + (j + 1) + " time: " + duration + " ms"); + sum += duration; + } + System.out.println("hashVal: " + hashVal); + System.out.println("avg time: " + sum / round + " ms"); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashJoinOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashJoinOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..0771f8b0cef322bdb44d4a3ccd50a7d36d31b0c3 --- /dev/null +++ 
b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestHashJoinOmniOperator.java @@ -0,0 +1,1089 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; +import io.airlift.units.DataSize; +import io.prestosql.RowPagesBuilder; +import io.prestosql.execution.Lifespan; +import io.prestosql.operator.Driver; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.JoinBridgeManager; +import io.prestosql.operator.LookupSourceFactory; +import io.prestosql.operator.LookupSourceProvider; +import io.prestosql.operator.Operator; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.operator.PartitionedLookupSourceFactory; +import io.prestosql.operator.PipelineContext; +import io.prestosql.operator.TaskContext; +import io.prestosql.operator.ValuesOperator.ValuesOperatorFactory; +import io.prestosql.operator.exchange.LocalExchange.LocalExchangeFactory; +import io.prestosql.operator.exchange.LocalExchange.LocalExchangeSinkFactoryId; +import io.prestosql.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory; +import 
io.prestosql.operator.exchange.LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory; +import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.TestingTaskContext; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static 
java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertNull; + +@Test(singleThreaded = true) +public class TestHashJoinOmniOperator +{ + private static final int PARTITION_COUNT = 4; + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + + @BeforeMethod + public void setUp() + { + // Before/AfterMethod is chosen here because the executor needs to be shutdown + // after every single test case to terminate outstanding threads, if any. + + // The line below is the same as newCachedThreadPool(daemonThreadsNamed(...)) + // except RejectionExecutionHandler. + // RejectionExecutionHandler is set to DiscardPolicy (instead of the default + // AbortPolicy) here. + // Otherwise, a large number of RejectedExecutionException will flood logging, + // resulting in Travis failure. + executor = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60L, SECONDS, new SynchronousQueue(), + daemonThreadsNamed("test-executor-%s"), new ThreadPoolExecutor.DiscardPolicy()); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @DataProvider(name = "hashJoinTestValues") + public static Object[][] hashJoinTestValuesProvider() + { + return new Object[][]{{true, true, true}, {true, true, false}, {true, false, true}, {true, false, false}, + {false, true, true}, {false, true, false}, {false, false, true}, {false, false, false}}; + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), + ImmutableList.of(BIGINT, BIGINT, BIGINT)) + .addSequencePage(10, 20, 30, 40); + BuildSideSetup buildSideSetup = 
setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), + ImmutableList.of(BIGINT, BIGINT, BIGINT)); + List probeInput = probePages.addSequencePage(1000, 0, 1000, 2000).build(); + + OperatorFactory joinOperatorFactory = innerJoinOperatorFactory(lookupSourceFactory, probePages, + buildSideSetup.getOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), + concat(probePages.getTypesWithoutHash(), buildPages.getTypesWithoutHash())) + .row(20L, 1020L, 2020L, 20L, 30L, 40L).row(21L, 1021L, 2021L, 21L, 31L, 41L) + .row(22L, 1022L, 2022L, 22L, 32L, 42L).row(23L, 1023L, 2023L, 23L, 33L, 43L) + .row(24L, 1024L, 2024L, 24L, 34L, 44L).row(25L, 1025L, 2025L, 25L, 35L, 45L) + .row(26L, 1026L, 2026L, 26L, 36L, 46L).row(27L, 1027L, 2027L, 27L, 37L, 47L) + .row(28L, 1028L, 2028L, 28L, 38L, 48L).row(29L, 1029L, 2029L, 29L, 39L, 49L) + .build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row(1L) + .row(2L).row(3L); + BuildSideSetup buildSideSetup = 
setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row(1L).row((String) null).row((String) null).row(1L).row(2L).build(); + + OperatorFactory joinOperatorFactory = innerJoinOperatorFactory(lookupSourceFactory, probePages, + buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildPages.getTypesWithoutHash())) + .row(1L, 1L).row(1L, 1L).row(2L, 2L).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithNullBuild(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row(1L) + .row((String) null).row((String) null).row(1L).row(2L); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = 
rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row(1L).row(2L).row(3L).build(); + + OperatorFactory joinOperatorFactory = innerJoinOperatorFactory(lookupSourceFactory, probePages, + buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row(1L, 1L).row(1L, 1L) + .row(2L, 2L).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, + boolean buildHashEnabled) + { + // TODO: failed since taking null as zero now, open when omni-runtime support + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row(1L) + .row((String) null).row((String) null).row(1L).row(2L); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row(1L).row(2L).row((String) null).row(3L).build(); + + OperatorFactory joinOperatorFactory = innerJoinOperatorFactory(lookupSourceFactory, probePages, + 
buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row(1L, 1L).row(1L, 1L) + .row(2L, 2L).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testProbeOuterJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), + ImmutableList.of(VARCHAR, BIGINT, BIGINT)).addSequencePage(10, 20, 30, 40); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactoryManager = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.addSequencePage(15, 20, 1020, 2020).build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactoryManager, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + 
// expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)) + .row("20", 1020L, 2020L, "20", 30L, 40L).row("21", 1021L, 2021L, "21", 31L, 41L) + .row("22", 1022L, 2022L, "22", 32L, 42L).row("23", 1023L, 2023L, "23", 33L, 43L) + .row("24", 1024L, 2024L, "24", 34L, 44L).row("25", 1025L, 2025L, "25", 35L, 45L) + .row("26", 1026L, 2026L, "26", 36L, 46L).row("27", 1027L, 2027L, "27", 37L, 47L) + .row("28", 1028L, 2028L, "28", 38L, 48L).row("29", 1029L, 2029L, "29", 39L, 49L) + .row("30", 1030L, 2030L, null, null, null).row("31", 1031L, 2031L, null, null, null) + .row("32", 1032L, 2032L, null, null, null).row("33", 1033L, 2033L, null, null, null) + .row("34", 1034L, 2034L, null, null, null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testProbeOuterJoinWithFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + String filterFunction = "{" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"GREATER_THAN_OR_EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 2," + + " \"colVal\": 1" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 2," + + " \"isNull\": false," + + " \"value\": 1025" + + " }" + + "}"; + + // build factory + List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), + ImmutableList.of(VARCHAR, BIGINT, BIGINT)).addSequencePage(10, 20, 30, 40); + BuildSideSetup buildSideSetup = 
setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.of(filterFunction)); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.addSequencePage(15, 20, 1020, 2020).build(); + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult.resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)) + .row("20", 1020L, 2020L, null, null, null).row("21", 1021L, 2021L, null, null, null) + .row("22", 1022L, 2022L, null, null, null).row("23", 1023L, 2023L, null, null, null) + .row("24", 1024L, 2024L, null, null, null).row("25", 1025L, 2025L, "25", 35L, 45L) + .row("26", 1026L, 2026L, "26", 36L, 46L).row("27", 1027L, 2027L, "27", 37L, 47L) + .row("28", 1028L, 2028L, "28", 38L, 48L).row("29", 1029L, 2029L, "29", 39L, 49L) + .row("30", 1030L, 2030L, null, null, null).row("31", 1031L, 2031L, null, null, null) + .row("32", 1032L, 2032L, null, null, null).row("33", 1033L, 2033L, null, null, null) + .row("34", 1034L, 2034L, null, null, null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void 
testOuterJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row("a").row("b") + .row("c"); + + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row((String) null).row((String) null).row("a").row("b").build(); + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", "a").row(null, null) + .row(null, null).row("a", "a").row("b", "b").build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + 
String filterFunction = "{" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 15," + + " \"width\": 1," + + " \"colVal\": 0" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 15," + + " \"isNull\": false," + + " \"width\": 1," + + " \"value\": \"a\"" + + " }" + + "}"; + + // build factory + List buildTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row("a").row("b") + .row("c"); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.of(filterFunction)); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row((String) null).row((String) null).row("a").row("b").build(); + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", "a").row(null, null) + .row(null, null).row("a", "a").row("b", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, 
buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testOuterJoinWithNullBuild(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), ImmutableList.of(VARCHAR)) + .row("a").row((String) null).row((String) null) + .row("a").row("b"); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row("b").row("c").build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", "a").row("a", "a") + .row("b", "b").row("c", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testOuterJoinWithNullBuildAndFilterFunction(boolean 
parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + String filterFunction = "{" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"OR\"," + + " \"left\": " + + " {" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 15," + + " \"width\": 1," + + " \"colVal\": 0" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 15," + + " \"isNull\": false," + + " \"width\": 1," + + " \"value\": \"a\"" + + " }" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 15," + + " \"width\": 1," + + " \"colVal\": 0" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 15," + + " \"isNull\": false," + + " \"width\": 1," + + " \"value\": \"c\"" + + " }" + + " }" + + "}"; + + // build factory + List buildTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), ImmutableList.of(VARCHAR)) + .row("a").row((String) null).row((String) null).row("a").row("b"); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.of(filterFunction)); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row("b").row("c").build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), 
Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", "a").row("a", "a") + .row("b", null).row("c", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testOuterJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), ImmutableList.of(VARCHAR)) + .row("a").row((String) null).row((String) null).row("a").row("b"); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row("b").row((String) null).row("c").build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + 
instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildPages.getTypesWithoutHash())) + .row("a", "a").row("a", "a").row("b", "b").row(null, null).row("c", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + String filterFunction = "{" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"OR\"," + + " \"left\": " + + " {" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 15," + + " \"width\": 1," + + " \"colVal\": 0" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 15," + + " \"isNull\": false," + + " \"width\": 1," + + " \"value\": \"a\"" + + " }" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"BINARY\"," + + " \"returnType\": 4," + + " \"operator\": \"EQUAL\"," + + " \"left\": " + + " {" + + " \"exprType\": \"FIELD_REFERENCE\"," + + " \"dataType\": 15," + + " \"width\": 1," + + " \"colVal\": 0" + + " }," + + " \"right\": " + + " {" + + " \"exprType\": \"LITERAL\"," + + " \"dataType\": 15," + + " \"isNull\": false," + + " \"width\": 1," + + " \"value\": \"c\"" + + " }" + + " }" + + "}"; + + // build factory + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), 
ImmutableList.of(VARCHAR)) + .row("a").row((String) null).row((String) null).row("a").row("b"); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.of(filterFunction)); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row("a").row("b").row((String) null).row("c").build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactory, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildPages.getTypesWithoutHash())) + .row("a", "a").row("a", "a").row("b", null).row(null, null).row("c", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), 
buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactoryManager = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.innerJoin(0, new PlanNodeId("test"), + lookupSourceFactoryManager, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + Operator operator = joinOperatorFactory.createOperator(driverContext); + + List pages = probePages.row(6L).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + operator.addInput(offHeapPages.get(0)); + Page outputPage = operator.getOutput(); + assertNull(outputPage); + joinOperatorFactory.noMoreOperators(driverContext.getLifespan()); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + // build factory + List buildTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactoryManager = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(VARCHAR); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List 
probeInput = probePages + .row("a").row("b").row((String) null).row("c").build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.probeOuterJoin(0, new PlanNodeId("test"), + lookupSourceFactoryManager, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", null).row("b", null) + .row(null, null).row("c", null).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + { + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row(1L).row(2L) + .row((String) null).row(3L); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactoryManager = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.build(); + + OperatorFactory joinOperatorFactory = LookupJoinOmniOperators.innerJoin(0, new PlanNodeId("test"), + 
lookupSourceFactoryManager, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), + Optional.empty(), OptionalInt.of(1), buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup); + + // expected + MaterializedResult expected = MaterializedResult.resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).build(); + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + @Test(dataProvider = "hashJoinTestValues") + public void testInnerJoinWithZeroAndNulls(boolean parallelBuild, boolean probeHashEnabled, + boolean buildHashEnabled) + { + // TODO: this test fails because omni-runtime currently treats null as zero; enable it once omni-runtime supports null join keys + TaskContext taskContext = createTaskContext(); + + // build factory + List buildTypes = ImmutableList.of(BIGINT); + RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes).row(1L) + .row((String) null).row(0L).row(1L).row(2L); + BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, + Optional.empty()); + JoinBridgeManager lookupSourceFactory = buildSideSetup + .getLookupSourceFactoryManager(); + + // probe factory + List probeTypes = ImmutableList.of(BIGINT); + RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes); + List probeInput = probePages.row(1L).row(2L).row((String) null).row(0L).row(3L).build(); + + OperatorFactory joinOperatorFactory = innerJoinOperatorFactory(lookupSourceFactory, probePages, + buildSideSetup.getBuildOperatorFactory()); + + // build drivers and operators + instantiateBuildDrivers(buildSideSetup, taskContext); + buildLookupSource(buildSideSetup);
+ + // expected + MaterializedResult expected = MaterializedResult + .resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row(1L, 1L).row(1L, 1L) + .row(2L, 2L).row(0L, 0L).build(); + + DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, true, + getHashChannels(probePages, buildPages)); + joinOperatorFactory.noMoreOperators(); + } + + private TaskContext createTaskContext() + { + return TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION); + } + + private static List getHashChannels(RowPagesBuilder probe, RowPagesBuilder build) + { + ImmutableList.Builder hashChannels = ImmutableList.builder(); + if (probe.getHashChannel().isPresent()) { + hashChannels.add(probe.getHashChannel().get()); + } + if (build.getHashChannel().isPresent()) { + hashChannels.add(probe.getTypes().size() + build.getHashChannel().get()); + } + return hashChannels.build(); + } + + private OperatorFactory innerJoinOperatorFactory( + JoinBridgeManager lookupSourceFactoryManager, RowPagesBuilder probePages, + OperatorFactory operatorFactory) + { + return LookupJoinOmniOperators.innerJoin(0, new PlanNodeId("test"), lookupSourceFactoryManager, + probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), + OptionalInt.of(1), (HashBuilderOmniOperator.HashBuilderOmniOperatorFactory) operatorFactory); + } + + private BuildSideSetup setupBuildSide(boolean parallelBuild, TaskContext taskContext, List hashChannels, + RowPagesBuilder buildPages, Optional filterFunction) + { + int partitionCount = parallelBuild ? 
PARTITION_COUNT : 1; + LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory(FIXED_HASH_DISTRIBUTION, partitionCount, + buildPages.getTypes(), hashChannels, buildPages.getHashChannel(), UNGROUPED_EXECUTION, + new DataSize(32, DataSize.Unit.MEGABYTE)); + LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); + localExchangeFactory.noMoreSinkFactories(); + + // collect input data into the partitioned exchange + DriverContext collectDriverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); + + List builds = buildPages.build(); + List toOffHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, builds); + ValuesOperatorFactory valuesOperatorFactory = new ValuesOperatorFactory(0, new PlanNodeId("values"), + toOffHeapPages); + LocalExchangeSinkOperatorFactory sinkOperatorFactory = new LocalExchangeSinkOperatorFactory( + localExchangeFactory, 1, new PlanNodeId("sink"), localExchangeSinkFactoryId, Function.identity()); + Driver sourceDriver = Driver.createDriver(collectDriverContext, + valuesOperatorFactory.createOperator(collectDriverContext), + sinkOperatorFactory.createOperator(collectDriverContext)); + valuesOperatorFactory.noMoreOperators(); + sinkOperatorFactory.noMoreOperators(); + + while (!sourceDriver.isFinished()) { + sourceDriver.process(); + } + + // build side operator factories + LocalExchangeSourceOperatorFactory sourceOperatorFactory = new LocalExchangeSourceOperatorFactory(0, + new PlanNodeId("source"), localExchangeFactory, 1); + + JoinBridgeManager lookupSourceFactoryManager = JoinBridgeManager + .lookupAllAtOnce(new PartitionedLookupSourceFactory(buildPages.getTypes(), + rangeList(buildPages.getTypes().size()).stream().map(buildPages.getTypes()::get) + .collect(toImmutableList()), + hashChannels.stream().map(buildPages.getTypes()::get).collect(toImmutableList()), + partitionCount, requireNonNull(ImmutableMap.of(), "layout is null"), 
false, false)); + + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory builderOmniOperatorFactory = new HashBuilderOmniOperator.HashBuilderOmniOperatorFactory( + 1, new PlanNodeId("build"), lookupSourceFactoryManager, buildPages.getTypes(), + rangeList(buildPages.getTypes().size()), hashChannels, + buildPages.getHashChannel().map(OptionalInt::of).orElse(OptionalInt.empty()), + filterFunction, + Optional.empty(), ImmutableList.of(), partitionCount); + + return new BuildSideSetup(lookupSourceFactoryManager, builderOmniOperatorFactory, sourceOperatorFactory, + partitionCount); + } + + private void instantiateBuildDrivers(BuildSideSetup buildSideSetup, TaskContext taskContext) + { + PipelineContext buildPipeline = taskContext.addPipelineContext(1, true, true, false); + List buildDrivers = new ArrayList<>(); + List buildOperators = new ArrayList<>(); + for (int i = 0; i < buildSideSetup.getPartitionCount(); i++) { + DriverContext buildDriverContext = buildPipeline.addDriverContext(); + HashBuilderOmniOperator buildOperator = (HashBuilderOmniOperator) buildSideSetup.getBuildOperatorFactory() + .createOperator(buildDriverContext); + Driver driver = Driver.createDriver(buildDriverContext, + buildSideSetup.getBuildSideSourceOperatorFactory().createOperator(buildDriverContext), + buildOperator); + buildDrivers.add(driver); + buildOperators.add(buildOperator); + } + buildSideSetup.getBuildOperatorFactory().noMoreOperators(); + + buildSideSetup.setDriversAndOperators(buildDrivers, buildOperators); + } + + private void buildLookupSource(BuildSideSetup buildSideSetup) + { + requireNonNull(buildSideSetup, "buildSideSetup is null"); + + LookupSourceFactory lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager().getJoinBridge(Lifespan.taskWide()); + Future lookupSourceProvider = lookupSourceFactory.createLookupSourceProvider(); + List buildDrivers = buildSideSetup.getBuildDrivers(); + + while (!lookupSourceProvider.isDone()) { + for (Driver buildDriver : 
buildDrivers) { + buildDriver.process(); + } + } + getFutureValue(lookupSourceProvider).close(); + + for (Driver buildDriver : buildDrivers) { + runDriverInThread(executor, buildDriver); + } + } + + /** + * Runs Driver in another thread until it is finished + */ + private static void runDriverInThread(ExecutorService executor, Driver driver) + { + executor.execute(() -> { + if (!driver.isFinished()) { + try { + driver.process(); + } + catch (PrestoException e) { + driver.getDriverContext().failed(e); + throw e; + } + runDriverInThread(executor, driver); + } + }); + } + + private static OptionalInt getHashChannelAsInt(RowPagesBuilder probePages) + { + return probePages.getHashChannel() + .map(OptionalInt::of).orElse(OptionalInt.empty()); + } + + private static List rangeList(int endExclusive) + { + return IntStream.range(0, endExclusive) + .boxed() + .collect(toImmutableList()); + } + + private static List concat(List initialElements, List moreElements) + { + return ImmutableList.copyOf(Iterables.concat(initialElements, moreElements)); + } + + private static class BuildSideSetup + { + private final JoinBridgeManager lookupSourceFactoryManager; + private final HashBuilderOmniOperator.HashBuilderOmniOperatorFactory buildOperatorFactory; + private final LocalExchangeSourceOperatorFactory buildSideSourceOperatorFactory; + private final int partitionCount; + private List buildDrivers; + private List buildOperators; + + BuildSideSetup(JoinBridgeManager lookupSourceFactoryManager, + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory buildOperatorFactory, + LocalExchangeSourceOperatorFactory buildSideSourceOperatorFactory, int partitionCount) + { + this.lookupSourceFactoryManager = requireNonNull(lookupSourceFactoryManager, "lookupSourceFactoryManager is null"); + this.buildOperatorFactory = requireNonNull(buildOperatorFactory, "buildOperatorFactory is null"); + this.buildSideSourceOperatorFactory = buildSideSourceOperatorFactory; + this.partitionCount = partitionCount; 
+ } + + void setDriversAndOperators(List buildDrivers, List buildOperators) + { + checkArgument(buildDrivers.size() == buildOperators.size()); + this.buildDrivers = ImmutableList.copyOf(buildDrivers); + this.buildOperators = ImmutableList.copyOf(buildOperators); + } + + JoinBridgeManager getLookupSourceFactoryManager() + { + return lookupSourceFactoryManager; + } + + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory getBuildOperatorFactory() + { + return buildOperatorFactory; + } + + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory getOperatorFactory() + { + return buildOperatorFactory; + } + + public LocalExchangeSourceOperatorFactory getBuildSideSourceOperatorFactory() + { + return buildSideSourceOperatorFactory; + } + + public int getPartitionCount() + { + return partitionCount; + } + + List getBuildDrivers() + { + checkState(buildDrivers != null, "buildDrivers is not initialized yet"); + return buildDrivers; + } + + List getBuildOperators() + { + checkState(buildOperators != null, "buildOperators is not initialized yet"); + return buildOperators; + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLimitOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLimitOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..8c59252bab78c8f44cd3638267bfacac4dda3dcc --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLimitOmniOperator.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package nova.hetu.olk.operator;
+
+import com.google.common.collect.ImmutableList;
+import io.prestosql.operator.DriverContext;
+import io.prestosql.operator.OperatorFactory;
+import io.prestosql.spi.Page;
+import io.prestosql.spi.plan.PlanNodeId;
+import io.prestosql.testing.MaterializedResult;
+import nova.hetu.omniruntime.vector.VecAllocator;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static io.airlift.concurrent.Threads.daemonThreadsNamed;
+import static io.prestosql.RowPagesBuilder.rowPagesBuilder;
+import static io.prestosql.SessionTestUtils.TEST_SESSION;
+import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.spi.type.DoubleType.DOUBLE;
+import static io.prestosql.testing.MaterializedResult.resultBuilder;
+import static io.prestosql.testing.TestingTaskContext.createTaskContext;
+import static java.util.concurrent.Executors.newCachedThreadPool;
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages;
+
+/**
+ * Tests for {@code LimitOmniOperator}: feeds off-heap pages through a
+ * LimitOmniOperatorFactory-built operator and asserts the limited output.
+ */
+@Test(singleThreaded = true)
+public class TestLimitOmniOperator
+{
+    private ExecutorService executor;
+    private ScheduledExecutorService scheduledExecutor;
+    private DriverContext driverContext;
+
+    // Fresh executors and driver context per test method.
+    @BeforeMethod
+    public void setUp()
+    {
+        executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
+        scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
+        driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION)
+                .addPipelineContext(0, true, true, false).addDriverContext();
+    }
+
+    @AfterMethod(alwaysRun = true)
+    public void tearDown()
+    {
+        executor.shutdownNow();
+        scheduledExecutor.shutdownNow();
+    }
+
+    // Limit 3 over 7 input rows: first three rows (in page order) are kept.
+    @Test
+    public void testBasicLimit()
+    {
+        List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(-1L, -0.1)
+                .row(4L, 0.4).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build();
+        // transfer on-heap page to off-heap
+        List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input);
+
+        OperatorFactory operatorFactory = new LimitOmniOperator.LimitOmniOperatorFactory(0, new PlanNodeId("test"),
+                3, ImmutableList.of(BIGINT, DOUBLE));
+
+        MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(1L, 0.1)
+                .row(2L, 0.2).row(-1L, -0.1).build();
+
+        assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected);
+    }
+
+    // Null values count toward the limit and pass through unchanged.
+    @Test
+    public void testLimitNull()
+    {
+        List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(null, 0.2).pageBreak().row(-1L, -0.1)
+                .row(null, null).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build();
+        // transfer on-heap page to off-heap
+        List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input);
+
+        OperatorFactory operatorFactory = new LimitOmniOperator.LimitOmniOperatorFactory(0, new PlanNodeId("test"),
+                5, ImmutableList.of(BIGINT, DOUBLE));
+
+        MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(1L, 0.1)
+                .row(null, 0.2).row(-1L, -0.1).row(null, null).row(5L, 0.5).build();
+
+        assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected);
+    }
+
+    // Limit larger than the input: all rows are emitted.
+    @Test
+    public void testLimitLess()
+    {
+        List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(-1L, -0.1).build();
+        // transfer on-heap page to off-heap
+        List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input);
+
+        OperatorFactory operatorFactory = new LimitOmniOperator.LimitOmniOperatorFactory(0, new PlanNodeId("test"),
+                5, ImmutableList.of(BIGINT, DOUBLE));
+
+        MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(1L, 0.1).row(-1L, -0.1).build();
+
+        assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected);
+    }
+}
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalExchangeSourceOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalExchangeSourceOmniOperator.java
new file mode 100644
index 0000000000000000000000000000000000000000..55ea0e79e9cc37aebf06fe97ae3017c41ce34be1
--- /dev/null
+++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalExchangeSourceOmniOperator.java
@@ -0,0 +1,157 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package nova.hetu.olk.operator;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import io.airlift.units.DataSize;
+import io.prestosql.Session;
+import io.prestosql.execution.Lifespan;
+import io.prestosql.operator.DriverContext;
+import io.prestosql.operator.Operator;
+import io.prestosql.operator.OperatorFactory;
+import io.prestosql.operator.PipelineContext;
+import io.prestosql.operator.exchange.LocalExchange;
+import io.prestosql.operator.exchange.LocalExchangeSink;
+import io.prestosql.spi.plan.PlanNodeId;
+import io.prestosql.spi.type.Type;
+import io.prestosql.sql.planner.PartitioningHandle;
+import nova.hetu.olk.operator.localexchange.LocalExchangeSourceOmniOperator;
+import nova.hetu.olk.operator.localexchange.OmniLocalExchange;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static io.airlift.concurrent.Threads.daemonThreadsNamed;
+import static io.prestosql.SessionTestUtils.TEST_SNAPSHOT_SESSION;
+import static io.prestosql.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
+import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
+import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
+import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
+import static io.prestosql.testing.TestingTaskContext.createTaskContext;
+import static java.util.concurrent.Executors.newScheduledThreadPool;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Tests for {@code LocalExchangeSourceOmniOperator}: verifies that input
+ * channels become visible once all sinks are created, across the four
+ * system partitioning handles.
+ */
+@Test(singleThreaded = true)
+public class TestLocalExchangeSourceOmniOperator
+{
+    private static final List TYPES = ImmutableList.of(BIGINT);
+    private static final DataSize LOCAL_EXCHANGE_MAX_BUFFERED_BYTES = new DataSize(32, DataSize.Unit.MEGABYTE);
+
+    // Lazily created on first createOperator() call, reset per test method.
+    private PipelineContext pipelineContext;
+
+    @BeforeMethod
+    public void setup()
+    {
+        pipelineContext = null;
+    }
+
+    // One test run per partitioning handle.
+    @DataProvider
+    public static Object[][] markerDistributions()
+    {
+        return new Object[][]{{FIXED_BROADCAST_DISTRIBUTION}, {FIXED_ARBITRARY_DISTRIBUTION},
+                {FIXED_PASSTHROUGH_DISTRIBUTION}, {FIXED_HASH_DISTRIBUTION}};
+    }
+
+    @Test(dataProvider = "markerDistributions")
+    public void testGetInputChannels(PartitioningHandle partitioningHandle)
+    {
+        // hash distribution needs an explicit partition channel
+        List partitionChannels = partitioningHandle == FIXED_HASH_DISTRIBUTION
+                ? ImmutableList.of(0)
+                : ImmutableList.of();
+        LocalExchange.LocalExchangeFactory localExchangeFactory = new OmniLocalExchange.OmniLocalExchangeFactory(
+                partitioningHandle, 2, TYPES, partitionChannels, Optional.empty(), UNGROUPED_EXECUTION,
+                LOCAL_EXCHANGE_MAX_BUFFERED_BYTES);
+        LocalExchange.LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId();
+        localExchangeFactory.noMoreSinkFactories();
+
+        LocalExchangeSourceOmniOperator operatorA = createOperator(localExchangeFactory, 2, TEST_SNAPSHOT_SESSION);
+        LocalExchangeSourceOmniOperator operatorB = createOperator(localExchangeFactory, 2, TEST_SNAPSHOT_SESSION);
+
+        LocalExchange exchange = localExchangeFactory.getLocalExchange(Lifespan.taskWide());
+
+        // no sinks yet -> channels not visible
+        assertFalse(operatorA.getInputChannels().isPresent());
+        assertFalse(operatorB.getInputChannels().isPresent());
+
+        LocalExchange.LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId);
+        final String sinkAId = "sinkA";
+        LocalExchangeSink sinkA = sinkFactory.createSink(sinkAId);
+        // one of two expected sinks created -> still not visible
+        assertFalse(operatorA.getInputChannels().isPresent());
+        assertFalse(operatorB.getInputChannels().isPresent());
+
+        final String sinkBId = "sinkB";
+        LocalExchangeSink sinkB = sinkFactory.createSink(sinkBId);
+        // both sinks created -> both operators see both channel ids
+        assertTrue(operatorA.getInputChannels().isPresent());
+        assertEquals(operatorA.getInputChannels().get(), Sets.newHashSet(sinkAId, sinkBId));
+        assertTrue(operatorB.getInputChannels().isPresent());
+        assertEquals(operatorB.getInputChannels().get(), Sets.newHashSet(sinkAId, sinkBId));
+
+        sinkFactory.close();
+        sinkFactory.noMoreSinkFactories();
+
+        sinkA.finish();
+        sinkB.finish();
+    }
+
+    private LocalExchangeSourceOmniOperator createOperator(LocalExchange.LocalExchangeFactory localExchangeFactory,
+            int totalInputChannels, Session session)
+    {
+        return createOperator(localExchangeFactory, totalInputChannels, session, 0);
+    }
+
+    // Builds a factory, asserts it cannot be duplicated, and creates the operator.
+    private LocalExchangeSourceOmniOperator createOperator(LocalExchange.LocalExchangeFactory localExchangeFactory,
+            int totalInputChannels, Session session, int driverId)
+    {
+        if (pipelineContext == null) {
+            ScheduledExecutorService scheduler = newScheduledThreadPool(4, daemonThreadsNamed("test-%s"));
+            ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2,
+                    daemonThreadsNamed("test-scheduledExecutor-%s"));
+            OperatorFactory operatorFactory = new LocalExchangeSourceOmniOperator.LocalExchangeSourceOmniOperatorFactory(
+                    0, new PlanNodeId("test"), localExchangeFactory, totalInputChannels, TYPES);
+
+            pipelineContext = createTaskContext(scheduler, scheduledExecutor, session).addPipelineContext(0, true, true, false);
+        }
+        DriverContext driverContext = pipelineContext.addDriverContext(Lifespan.taskWide(), driverId);
+
+        OperatorFactory operatorFactory = new LocalExchangeSourceOmniOperator.LocalExchangeSourceOmniOperatorFactory(0,
+                new PlanNodeId("test"), localExchangeFactory, totalInputChannels, TYPES);
+
+        // duplicate() is expected to throw UnsupportedOperationException
+        boolean duplicated = true;
+        try {
+            operatorFactory = operatorFactory.duplicate();
+        }
+        catch (UnsupportedOperationException e) {
+            duplicated = false;
+        }
+        assertEquals(duplicated, false);
+
+        assertEquals(((LocalExchangeSourceOmniOperator.LocalExchangeSourceOmniOperatorFactory) operatorFactory)
+                .getLocalExchangeFactory(), localExchangeFactory);
+        assertEquals(operatorFactory.isExtensionOperatorFactory(), true);
+        assertEquals(operatorFactory.getSourceTypes(), TYPES);
+
+        Operator operator = operatorFactory.createOperator(driverContext);
+        assertEquals(operator.getOperatorContext().getOperatorStats().getSystemMemoryReservation().toBytes(), 0);
+        return (LocalExchangeSourceOmniOperator) operator;
+    }
+}
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalMergeSourceOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalMergeSourceOmniOperator.java
new file mode 100644
index 0000000000000000000000000000000000000000..36a590a9233e78ecad2f19820de3ac593098bcd6
--- /dev/null
+++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestLocalMergeSourceOmniOperator.java
@@ -0,0 +1,301 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package nova.hetu.olk.operator;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.units.DataSize;
+import io.prestosql.operator.DriverContext;
+import io.prestosql.operator.Operator;
+import io.prestosql.operator.PipelineExecutionStrategy;
+import io.prestosql.operator.exchange.LocalExchange;
+import io.prestosql.operator.exchange.LocalExchangeSink;
+import io.prestosql.spi.Page;
+import io.prestosql.spi.block.SortOrder;
+import io.prestosql.spi.plan.PlanNodeId;
+import io.prestosql.spi.type.Type;
+import io.prestosql.sql.gen.OrderingCompiler;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static io.airlift.concurrent.Threads.daemonThreadsNamed;
+import static io.prestosql.RowPagesBuilder.rowPagesBuilder;
+import static io.prestosql.SessionTestUtils.TEST_SESSION;
+import static io.prestosql.operator.OperatorAssertion.assertOperatorIsBlocked;
+import static io.prestosql.operator.OperatorAssertion.assertOperatorIsUnblocked;
+import static io.prestosql.operator.PageAssertions.assertPageEquals;
+import static io.prestosql.spi.block.SortOrder.ASC_NULLS_FIRST;
+import static io.prestosql.spi.block.SortOrder.DESC_NULLS_FIRST;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.spi.type.IntegerType.INTEGER;
+import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
+import static io.prestosql.testing.TestingTaskContext.createTaskContext;
+import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Tests for {@code LocalMergeSourceOmniOperator}: pushes sorted pages into
+ * local-exchange sinks and asserts the merged, sorted output page.
+ */
+@Test(singleThreaded = true)
+public class TestLocalMergeSourceOmniOperator
+{
+    // provides unique operator/plan-node ids across tests
+    private AtomicInteger operatorId = new AtomicInteger();
+
+    private ScheduledExecutorService executor;
+    private LocalExchange.LocalExchangeFactory localExchangeFactory;
+    private OrderingCompiler orderingCompiler;
+
+    @BeforeMethod
+    public void setUp()
+    {
+        executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("test-local-merge-source-omni-operator-%s"));
+        orderingCompiler = new OrderingCompiler();
+    }
+
+    @AfterMethod(alwaysRun = true)
+    public void tearDown()
+    {
+        orderingCompiler = null;
+
+        executor.shutdownNow();
+        executor = null;
+
+        localExchangeFactory = null;
+    }
+
+    // Two sinks, one sorted page each; expects a single merged output column.
+    @Test
+    public void testSingleStream() throws Exception
+    {
+        List types = ImmutableList.of(BIGINT, BIGINT);
+        int defaultConcurrency = 2;
+        PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy = PipelineExecutionStrategy.UNGROUPED_EXECUTION;
+
+        localExchangeFactory = new LocalExchange.LocalExchangeFactory(FIXED_PASSTHROUGH_DISTRIBUTION,
+                defaultConcurrency, types, ImmutableList.of(), Optional.empty(),
+                exchangeSourcePipelineExecutionStrategy, new DataSize(32, DataSize.Unit.MEGABYTE));
+        LocalExchange.LocalExchangeSinkFactoryId sinkFactoryId = localExchangeFactory.newSinkFactoryId();
+        localExchangeFactory.noMoreSinkFactories();
+
+        DriverContext driverContext = createTaskContext(executor, executor, TEST_SESSION)
+                .addPipelineContext(0, true, true, false).addDriverContext();
+
+        LocalExchange localExchange = localExchangeFactory.getLocalExchange(driverContext.getLifespan());
+        LocalExchange.LocalExchangeSinkFactory sinkFactory = localExchange.getSinkFactory(sinkFactoryId);
+
+        List sinkList = new ArrayList<>();
+        for (int i = 0; i < defaultConcurrency; i++) {
+            sinkList.add(sinkFactory.createSink(""));
+        }
+
+        sinkFactory.close();
+        sinkFactory.noMoreSinkFactories();
+
+        LocalMergeSourceOmniOperator operator = createLocalMergeSourceOmniOperator(types, ImmutableList.of(1),
+                ImmutableList.of(0, 1), ImmutableList.of(ASC_NULLS_FIRST, ASC_NULLS_FIRST), driverContext);
+
+        List input = rowPagesBuilder(types).row(1, 1).row(2, 2).pageBreak().row(3, 3).row(4, 4).build();
+
+        assertFalse(operator.isFinished());
+        assertOperatorIsBlocked(operator);
+        sinkList.get(0).addPage(input.get(0), null);
+        assertOperatorIsUnblocked(operator);
+
+        sinkList.get(1).addPage(input.get(1), null);
+        assertOperatorIsUnblocked(operator);
+
+        // no output until sinks are finished
+        assertNull(operator.getOutput());
+        sinkFinish(sinkList);
+        assertOperatorIsUnblocked(operator);
+
+        Page expected = rowPagesBuilder(BIGINT).row(1).row(2).row(3).row(4).build().get(0);
+        assertPageEquals(ImmutableList.of(BIGINT), getOnlyElement(pullAvailablePages(operator)), expected);
+        operator.close();
+    }
+
+    // Mixed BIGINT/INTEGER columns with reordered output channels and mixed sort orders.
+    @Test
+    public void testMergeDifferentTypes() throws Exception
+    {
+        ImmutableList types = ImmutableList.of(BIGINT, INTEGER);
+        int defaultConcurrency = 2;
+        PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy = PipelineExecutionStrategy.UNGROUPED_EXECUTION;
+        localExchangeFactory = new LocalExchange.LocalExchangeFactory(FIXED_PASSTHROUGH_DISTRIBUTION,
+                defaultConcurrency, types, ImmutableList.of(), Optional.empty(),
+                exchangeSourcePipelineExecutionStrategy, new DataSize(32, DataSize.Unit.MEGABYTE));
+        LocalExchange.LocalExchangeSinkFactoryId sinkFactoryId = localExchangeFactory.newSinkFactoryId();
+        localExchangeFactory.noMoreSinkFactories();
+
+        DriverContext driverContext = createTaskContext(executor, executor, TEST_SESSION)
+                .addPipelineContext(0, true, true, false).addDriverContext();
+
+        LocalExchange localExchange = localExchangeFactory.getLocalExchange(driverContext.getLifespan());
+        LocalExchange.LocalExchangeSinkFactory sinkFactory = localExchange.getSinkFactory(sinkFactoryId);
+
+        List sinkList = new ArrayList<>();
+        for (int i = 0; i < defaultConcurrency; i++) {
+            sinkList.add(sinkFactory.createSink(""));
+        }
+
+        sinkFactory.close();
+        sinkFactory.noMoreSinkFactories();
+
+        LocalMergeSourceOmniOperator operator = createLocalMergeSourceOmniOperator(types, ImmutableList.of(1, 0),
+                ImmutableList.of(1, 0), ImmutableList.of(DESC_NULLS_FIRST, ASC_NULLS_FIRST), driverContext);
+
+        List input1 = rowPagesBuilder(types).row(0, null).row(1, 4).row(2, 3).build();
+        List input2 = rowPagesBuilder(types).row(null, 5).row(2, 5).row(4, 3).build();
+
+        assertFalse(operator.isFinished());
+        assertOperatorIsBlocked(operator);
+        sinkList.get(0).addPage(input1.get(0), null);
+        assertOperatorIsUnblocked(operator);
+
+        sinkList.get(1).addPage(input2.get(0), null);
+        assertOperatorIsUnblocked(operator);
+
+        assertNull(operator.getOutput());
+        sinkFinish(sinkList);
+        assertOperatorIsUnblocked(operator);
+
+        // output channels are (1, 0), so columns come back swapped
+        ImmutableList outputTypes = ImmutableList.of(INTEGER, BIGINT);
+        Page expected = rowPagesBuilder(outputTypes).row(null, 0).row(5, null).row(5, 2).row(4, 1).row(3, 2).row(3, 4)
+                .build().get(0);
+
+        assertPageEquals(outputTypes, getOnlyElement(pullAvailablePages(operator)), expected);
+        operator.close();
+    }
+
+    // Eight sinks fed from three multi-page sources; merged on column 0 only.
+    @Test
+    public void testMultipleStreamsSameOutputColumns() throws Exception
+    {
+        List types = ImmutableList.of(BIGINT, BIGINT, BIGINT);
+        int defaultConcurrency = 8;
+        PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy = PipelineExecutionStrategy.UNGROUPED_EXECUTION;
+        localExchangeFactory = new LocalExchange.LocalExchangeFactory(FIXED_PASSTHROUGH_DISTRIBUTION,
+                defaultConcurrency, types, ImmutableList.of(), Optional.empty(),
+                exchangeSourcePipelineExecutionStrategy, new DataSize(32, DataSize.Unit.MEGABYTE));
+        LocalExchange.LocalExchangeSinkFactoryId sinkFactoryId = localExchangeFactory.newSinkFactoryId();
+        localExchangeFactory.noMoreSinkFactories();
+
+        DriverContext driverContext = createTaskContext(executor, executor, TEST_SESSION)
+                .addPipelineContext(0, true, true, false).addDriverContext();
+
+        LocalExchange localExchange = localExchangeFactory.getLocalExchange(driverContext.getLifespan());
+        LocalExchange.LocalExchangeSinkFactory sinkFactory = localExchange.getSinkFactory(sinkFactoryId);
+
+        List sinkList = new ArrayList<>();
+        for (int i = 0; i < defaultConcurrency; i++) {
+            sinkList.add(sinkFactory.createSink(""));
+        }
+
+        sinkFactory.close();
+        sinkFactory.noMoreSinkFactories();
+
+        LocalMergeSourceOmniOperator operator = createLocalMergeSourceOmniOperator(types, ImmutableList.of(0, 1, 2),
+                ImmutableList.of(0), ImmutableList.of(ASC_NULLS_FIRST), driverContext);
+        assertOperatorIsBlocked(operator);
+
+        List input1 = rowPagesBuilder(types).row(1, 1, 2).row(8, 1, 1).row(19, 1, 3).row(27, 1, 4).row(41, 2, 5)
+                .pageBreak().row(55, 1, 2).row(89, 1, 3).row(101, 1, 4).row(202, 1, 3).row(399, 2, 2)
+                .pageBreak().row(400, 1, 1).row(401, 1, 7).row(402, 1, 6).build();
+        List input2 = rowPagesBuilder(types).row(2, 1, 2).row(8, 1, 1).row(19, 1, 3).row(25, 1, 4).row(26, 2, 5)
+                .pageBreak().row(56, 1, 2).row(66, 1, 3).row(77, 1, 4).row(88, 1, 3).row(99, 2, 2)
+                .pageBreak().row(99, 1, 1).row(100, 1, 7).row(100, 1, 6).build();
+        List input3 = rowPagesBuilder(types).row(88, 1, 3).row(89, 1, 3).row(90, 1, 3).row(91, 1, 4).row(92, 2, 5)
+                .pageBreak().row(93, 1, 2).row(94, 1, 3).row(95, 1, 4).row(97, 1, 3).row(98, 2, 2).build();
+
+        assertOperatorIsBlocked(operator);
+
+        sinkList.get(0).addPage(input1.get(0), null);
+        sinkList.get(1).addPage(input1.get(1), null);
+        sinkList.get(2).addPage(input1.get(2), null);
+        sinkList.get(3).addPage(input2.get(0), null);
+        sinkList.get(4).addPage(input2.get(1), null);
+        sinkList.get(5).addPage(input2.get(2), null);
+        sinkList.get(6).addPage(input3.get(0), null);
+        sinkList.get(7).addPage(input3.get(1), null);
+
+        assertOperatorIsUnblocked(operator);
+        assertNull(operator.getOutput());
+        sinkFinish(sinkList);
+        assertOperatorIsUnblocked(operator);
+
+        Page expected = rowPagesBuilder(types).row(1, 1, 2).row(2, 1, 2).row(8, 1, 1).row(8, 1, 1).row(19, 1, 3)
+                .row(19, 1, 3).row(25, 1, 4).row(26, 2, 5).row(27, 1, 4).row(41, 2, 5).row(55, 1, 2).row(56, 1, 2)
+                .row(66, 1, 3).row(77, 1, 4).row(88, 1, 3).row(88, 1, 3).row(89, 1, 3).row(89, 1, 3).row(90, 1, 3)
+                .row(91, 1, 4).row(92, 2, 5).row(93, 1, 2).row(94, 1, 3).row(95, 1, 4).row(97, 1, 3).row(98, 2, 2)
+                .row(99, 2, 2).row(99, 1, 1).row(100, 1, 6).row(100, 1, 7).row(101, 1, 4).row(202, 1, 3).row(399, 2, 2)
+                .row(400, 1, 1).row(401, 1, 7).row(402, 1, 6).build().get(0);
+
+        assertPageEquals(types, getOnlyElement(pullAvailablePages(operator)), expected);
+        operator.close();
+    }
+
+    // Builds the operator under test from the shared localExchangeFactory.
+    private LocalMergeSourceOmniOperator createLocalMergeSourceOmniOperator(List sourceTypes,
+            List outputChannels, List sortChannels, List sortOrder,
+            DriverContext driverContext)
+    {
+        int mergeOperatorId = operatorId.getAndIncrement();
+        int orderByOmniId = mergeOperatorId;
+        LocalMergeSourceOmniOperator.LocalMergeSourceOmniOperatorFactory factory = new LocalMergeSourceOmniOperator.LocalMergeSourceOmniOperatorFactory(
+                mergeOperatorId, orderByOmniId, new PlanNodeId("plan_node_id" + mergeOperatorId), localExchangeFactory,
+                sourceTypes, orderingCompiler, sortChannels, sortOrder, outputChannels);
+
+        return (LocalMergeSourceOmniOperator) factory.createOperator(driverContext);
+    }
+
+    // Polls the operator (10s deadline) until finished, collecting output pages.
+    private static List pullAvailablePages(Operator operator) throws InterruptedException
+    {
+        long endTime = System.nanoTime() + TimeUnit.SECONDS.toNanos(10);
+        List outputPages = new ArrayList<>();
+
+        assertOperatorIsUnblocked(operator);
+
+        while (!operator.isFinished() && System.nanoTime() - endTime < 0) {
+            assertFalse(operator.needsInput());
+            Page outputPage = operator.getOutput();
+            if (outputPage != null) {
+                outputPages.add(outputPage);
+            }
+            else {
+                Thread.sleep(10);
+            }
+        }
+
+        // verify state
+        assertFalse(operator.needsInput(), "Operator still wants input");
+        assertTrue(operator.isFinished(), "Expected operator to be finished");
+
+        return outputPages;
+    }
+
+    // Finishes every sink; tolerates a null list.
+    private void sinkFinish(List sinkList)
+    {
+        if (sinkList == null) {
+            return;
+        }
+
+        for (LocalExchangeSink sink : sinkList) {
+            sink.finish();
+        }
+    }
+}
diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergeOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergeOmniOperator.java
new file mode 100644
index 0000000000000000000000000000000000000000..db33afedf933909a9ecc82f7323dbf6e309446f2
--- /dev/null
+++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergeOmniOperator.java
@@ -0,0 +1,326 @@
+/*
+ * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableList; +import com.google.inject.Injector; +import com.google.inject.Key; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.discovery.client.ServiceSelector; +import io.airlift.discovery.client.testing.TestingDiscoveryModule; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.testing.TestingHttpClient; +import io.airlift.http.server.testing.TestingHttpServerModule; +import io.airlift.jaxrs.JaxrsModule; +import io.airlift.jmx.testing.TestingJmxModule; +import io.airlift.json.JsonModule; +import io.airlift.node.testing.TestingNodeModule; +import io.airlift.tracetoken.TraceTokenModule; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.prestosql.execution.Lifespan; +import io.prestosql.execution.QueryManagerConfig; +import io.prestosql.failuredetector.FailureDetectorModule; +import io.prestosql.failuredetector.HeartbeatFailureDetector; +import io.prestosql.failuredetector.TestHeartbeatFailureDetector; +import io.prestosql.metadata.Split; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.ExchangeClientConfig; +import io.prestosql.operator.ExchangeClientFactory; +import io.prestosql.operator.ExchangeOperator; +import io.prestosql.operator.Operator; +import io.prestosql.operator.TestingExchangeHttpClientHandler; +import io.prestosql.operator.TestingTaskBuffer; +import io.prestosql.server.InternalCommunicationConfig; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.Type; +import io.prestosql.split.RemoteSplit; +import io.prestosql.sql.gen.OrderingCompiler; +import io.prestosql.testing.TestingPagesSerdeFactory; +import org.testng.annotations.AfterMethod; +import 
org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; +import static io.airlift.discovery.client.ServiceTypes.serviceType; +import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorIsBlocked; +import static io.prestosql.operator.OperatorAssertion.assertOperatorIsUnblocked; +import static io.prestosql.operator.PageAssertions.assertPageEquals; +import static io.prestosql.spi.block.SortOrder.ASC_NULLS_FIRST; +import static io.prestosql.spi.block.SortOrder.DESC_NULLS_FIRST; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestMergeOmniOperator +{ + private static final String TASK_1_ID = "task1"; + private static final String TASK_2_ID = "task2"; + private static final String TASK_3_ID = "task3"; + + private AtomicInteger operatorId = new AtomicInteger(); + + private ScheduledExecutorService executor; + private PagesSerdeFactory 
serdeFactory; + private HttpClient httpClient; + private ExchangeClientFactory exchangeClientFactory; + private OrderingCompiler orderingCompiler; + + private LoadingCache taskBuffers; + + @BeforeMethod + public void setUp() + { + Bootstrap app = new Bootstrap(new TestingNodeModule(), new TestingJmxModule(), new TestingDiscoveryModule(), + new TestingHttpServerModule(), new TraceTokenModule(), new JsonModule(), new JaxrsModule(), + new FailureDetectorModule(), binder -> { + configBinder(binder).bindConfig(InternalCommunicationConfig.class); + configBinder(binder).bindConfig(QueryManagerConfig.class); + discoveryBinder(binder).bindSelector("presto"); + discoveryBinder(binder).bindHttpAnnouncement("presto"); + + // Jersey with jetty 9 requires at least one resource + // todo add a dummy resource to airlift jaxrs in this case + jaxrsBinder(binder).bind(TestHeartbeatFailureDetector.FooResource.class); + }); + + Injector injector = app.strictConfig().doNotInitializeLogging().quiet().initialize(); + + ServiceSelector selector = injector.getInstance(Key.get(ServiceSelector.class, serviceType("presto"))); + assertEquals(selector.selectAllServices().size(), 1); + + HeartbeatFailureDetector detector = injector.getInstance(HeartbeatFailureDetector.class); + executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("test-merge-omni-operator-%s")); + serdeFactory = new TestingPagesSerdeFactory(); + + taskBuffers = CacheBuilder.newBuilder().build(CacheLoader.from(TestingTaskBuffer::new)); + httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), executor); + ExchangeClientConfig exchangeClientConfig = new ExchangeClientConfig(); + exchangeClientFactory = new ExchangeClientFactory(exchangeClientConfig, httpClient, executor, detector); + orderingCompiler = new OrderingCompiler(); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + serdeFactory = null; + orderingCompiler = null; + + httpClient.close(); + httpClient = null; + + 
executor.shutdownNow(); + executor = null; + + exchangeClientFactory.stop(); + exchangeClientFactory = null; + } + + @Test + public void testSingleStream() throws Exception + { + List types = ImmutableList.of(BIGINT, BIGINT); + + MergeOmniOperator operator = createMergeOmniOperator(types, ImmutableList.of(1), ImmutableList.of(0, 1), + ImmutableList.of(ASC_NULLS_FIRST, ASC_NULLS_FIRST)); + assertFalse(operator.isFinished()); + assertFalse(operator.isBlocked().isDone()); + + operator.addSplit(createRemoteSplit(TASK_1_ID)); + assertFalse(operator.isFinished()); + assertFalse(operator.isBlocked().isDone()); + + operator.noMoreSplits(); + + List input = rowPagesBuilder(types).row(1, 1).row(2, 2).pageBreak().row(3, 3).row(4, 4).build(); + + assertNull(operator.getOutput()); + assertFalse(operator.isFinished()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_1_ID).addPage(input.get(0), false); + assertOperatorIsUnblocked(operator); + + assertNull(operator.getOutput()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_1_ID).addPage(input.get(1), true); + assertOperatorIsUnblocked(operator); + + Page expected = rowPagesBuilder(BIGINT).row(1).row(2).row(3).row(4).build().get(0); + assertPageEquals(ImmutableList.of(BIGINT), getOnlyElement(pullAvailablePages(operator)), expected); + operator.close(); + } + + @Test + public void testMergeDifferentTypes() throws Exception + { + ImmutableList types = ImmutableList.of(BIGINT, INTEGER); + MergeOmniOperator operator = createMergeOmniOperator(types, ImmutableList.of(1, 0), ImmutableList.of(1, 0), + ImmutableList.of(DESC_NULLS_FIRST, ASC_NULLS_FIRST)); + operator.addSplit(createRemoteSplit(TASK_1_ID)); + operator.addSplit(createRemoteSplit(TASK_2_ID)); + operator.noMoreSplits(); + + List task1Pages = rowPagesBuilder(types).row(0, null).row(1, 4).row(2, 3).build(); + + List task2Pages = rowPagesBuilder(types).row(null, 5).row(2, 5).row(4, 3).build(); + + // blocked on first data source + 
assertNull(operator.getOutput()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_1_ID).addPages(task1Pages, true); + assertOperatorIsUnblocked(operator); + + // blocked on second data source + assertNull(operator.getOutput()); + taskBuffers.getUnchecked(TASK_2_ID).addPages(task2Pages, true); + assertOperatorIsUnblocked(operator); + + ImmutableList outputTypes = ImmutableList.of(INTEGER, BIGINT); + Page expected = rowPagesBuilder(outputTypes).row(null, 0).row(5, null).row(5, 2).row(4, 1).row(3, 2).row(3, 4) + .build().get(0); + + assertPageEquals(outputTypes, getOnlyElement(pullAvailablePages(operator)), expected); + operator.close(); + } + + @Test + public void testMultipleStreamsSameOutputColumns() throws Exception + { + List types = ImmutableList.of(BIGINT, BIGINT, BIGINT); + + MergeOmniOperator operator = createMergeOmniOperator(types, ImmutableList.of(0, 1, 2), ImmutableList.of(0), + ImmutableList.of(ASC_NULLS_FIRST)); + operator.addSplit(createRemoteSplit(TASK_1_ID)); + operator.addSplit(createRemoteSplit(TASK_2_ID)); + operator.addSplit(createRemoteSplit(TASK_3_ID)); + operator.noMoreSplits(); + + List source1Pages = rowPagesBuilder(types).row(1, 1, 2).row(8, 1, 1).row(19, 1, 3).row(27, 1, 4) + .row(41, 2, 5).pageBreak().row(55, 1, 2).row(89, 1, 3).row(101, 1, 4).row(202, 1, 3).row(399, 2, 2) + .pageBreak().row(400, 1, 1).row(401, 1, 7).row(402, 1, 6).build(); + + List source2Pages = rowPagesBuilder(types).row(2, 1, 2).row(8, 1, 1).row(19, 1, 3).row(25, 1, 4) + .row(26, 2, 5).pageBreak().row(56, 1, 2).row(66, 1, 3).row(77, 1, 4).row(88, 1, 3).row(99, 2, 2) + .pageBreak().row(99, 1, 1).row(100, 1, 7).row(100, 1, 6).build(); + + List source3Pages = rowPagesBuilder(types).row(88, 1, 3).row(89, 1, 3).row(90, 1, 3).row(91, 1, 4) + .row(92, 2, 5).pageBreak().row(93, 1, 2).row(94, 1, 3).row(95, 1, 4).row(97, 1, 3).row(98, 2, 2) + .build(); + + // blocked on first data source + assertNull(operator.getOutput()); + 
assertFalse(operator.isFinished()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_1_ID).addPage(source1Pages.get(0), false); + assertOperatorIsUnblocked(operator); + + // blocked on second data source + assertNull(operator.getOutput()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_2_ID).addPage(source2Pages.get(0), false); + + // blocked on third data source + assertNull(operator.getOutput()); + assertOperatorIsBlocked(operator); + taskBuffers.getUnchecked(TASK_3_ID).addPage(source3Pages.get(0), false); + + taskBuffers.getUnchecked(TASK_1_ID).addPage(source1Pages.get(1), false); + taskBuffers.getUnchecked(TASK_2_ID).addPage(source2Pages.get(1), false); + taskBuffers.getUnchecked(TASK_3_ID).addPage(source3Pages.get(1), true); + + taskBuffers.getUnchecked(TASK_2_ID).addPage(source2Pages.get(2), true); + taskBuffers.getUnchecked(TASK_1_ID).addPage(source1Pages.get(2), true); + + Page expected = rowPagesBuilder(types).row(1, 1, 2).row(2, 1, 2).row(8, 1, 1).row(8, 1, 1).row(19, 1, 3) + .row(19, 1, 3).row(25, 1, 4).row(26, 2, 5).row(27, 1, 4).row(41, 2, 5).row(55, 1, 2).row(56, 1, 2) + .row(66, 1, 3).row(77, 1, 4).row(88, 1, 3).row(88, 1, 3).row(89, 1, 3).row(89, 1, 3).row(90, 1, 3) + .row(91, 1, 4).row(92, 2, 5).row(93, 1, 2).row(94, 1, 3).row(95, 1, 4).row(97, 1, 3).row(98, 2, 2) + .row(99, 2, 2).row(99, 1, 1).row(100, 1, 6).row(100, 1, 7).row(101, 1, 4).row(202, 1, 3).row(399, 2, 2) + .row(400, 1, 1).row(401, 1, 7).row(402, 1, 6).build().get(0); + + assertPageEquals(types, getOnlyElement(pullAvailablePages(operator)), expected); + operator.close(); + } + + private MergeOmniOperator createMergeOmniOperator(List sourceTypes, List outputChannels, + List sortChannels, List sortOrder) + { + int mergeOperatorId = operatorId.getAndIncrement(); + MergeOmniOperator.MergeOmniOperatorFactory factory = new MergeOmniOperator.MergeOmniOperatorFactory( + mergeOperatorId, mergeOperatorId, new PlanNodeId("plan_node_id" + mergeOperatorId), 
+ exchangeClientFactory, serdeFactory, orderingCompiler, sourceTypes, outputChannels, sortChannels, + sortOrder); + + DriverContext driverContext = createTaskContext(executor, executor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + return (MergeOmniOperator) factory.createOperator(driverContext); + } + + private static Split createRemoteSplit(String taskId) + { + return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, + new RemoteSplit(URI.create("http://localhost/" + taskId), "new split test instance id"), + Lifespan.taskWide()); + } + + private static List pullAvailablePages(Operator operator) throws InterruptedException + { + long endTime = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + List outputPages = new ArrayList<>(); + + assertOperatorIsUnblocked(operator); + + while (!operator.isFinished() && System.nanoTime() - endTime < 0) { + assertFalse(operator.needsInput()); + Page outputPage = operator.getOutput(); + if (outputPage != null) { + outputPages.add(outputPage); + } + else { + Thread.sleep(10); + } + } + + // verify state + assertFalse(operator.needsInput(), "Operator still wants input"); + assertTrue(operator.isFinished(), "Expected operator to be finished"); + + return outputPages; + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergePagesOmni.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergePagesOmni.java new file mode 100644 index 0000000000000000000000000000000000000000..54e03556aaed23928bed3cc086c06ad12c2e4fbb --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestMergePagesOmni.java @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.WorkProcessor; +import io.prestosql.spi.Page; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.operator.filterandproject.OmniMergePages; +import org.testng.annotations.Test; + +import java.util.List; + +import static io.prestosql.SequencePageBuilder.createSequencePage; +import static io.prestosql.execution.buffer.PageSplitterUtil.splitPage; +import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static io.prestosql.operator.PageAssertions.assertPageEquals; +import static io.prestosql.operator.WorkProcessorAssertion.assertFinishes; +import static io.prestosql.operator.WorkProcessorAssertion.validateResult; +import static io.prestosql.operator.project.MergePages.mergePages; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static java.lang.Math.toIntExact; +import static org.testng.Assert.assertEquals; + +public class TestMergePagesOmni +{ + private static final List TYPES = ImmutableList.of(DATE, DOUBLE, BIGINT); + + @Test + public void testMinPageSizeThreshold() + { + Page page = createSequencePage(TYPES, 10); + + WorkProcessor mergePages = mergePages(TYPES, page.getSizeInBytes(), Integer.MAX_VALUE, Integer.MAX_VALUE, + pagesSource(page), newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, 
actualPage, page)); + assertFinishes(mergePages); + } + + @Test + public void testMinRowCountThreshold() + { + Page page = createSequencePage(TYPES, 20); + + WorkProcessor mergePages = mergePages(TYPES, 1024 * 1024, page.getPositionCount(), Integer.MAX_VALUE, + pagesSource(page), newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, actualPage, page)); + assertFinishes(mergePages); + } + + @Test + public void testBufferSmallPages() + { + int singlePageRowCount = 10; + Page page = createSequencePage(TYPES, singlePageRowCount * 2); + List splits = splitPage(page, page.getSizeInBytes() / 2); + + WorkProcessor mergePages = mergePages(TYPES, page.getSizeInBytes() + 1, page.getPositionCount() + 1, + Integer.MAX_VALUE, pagesSource(splits.get(0), splits.get(1)), newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, actualPage, page)); + assertFinishes(mergePages); + } + + @Test + public void testFlushOnBigPage() + { + Page smallPage = createSequencePage(TYPES, 20); + Page bigPage = createSequencePage(TYPES, 100); + + WorkProcessor mergePages = mergePages(TYPES, bigPage.getSizeInBytes(), bigPage.getPositionCount(), + Integer.MAX_VALUE, pagesSource(smallPage, bigPage), newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, actualPage, smallPage)); + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, actualPage, bigPage)); + assertFinishes(mergePages); + } + + @Test + public void testFlushOnFullPage() + { + int singlePageRowCount = 50; + List types = ImmutableList.of(BIGINT); + Page page = createSequencePage(types, singlePageRowCount * 2); + List splits = splitPage(page, page.getSizeInBytes() / 2); + + WorkProcessor mergePages = mergePages(types, page.getSizeInBytes() / 2 + 1, + page.getPositionCount() / 2 + 1, toIntExact(page.getSizeInBytes()), + pagesSource(splits.get(0), splits.get(1), splits.get(0), 
splits.get(1)), + newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(types, actualPage, page)); + validateResult(mergePages, actualPage -> assertPageEquals(types, actualPage, page)); + assertFinishes(mergePages); + } + + @Test + public void testMergeOnLargePage() + { + int singlePageRowCount = 10; + Page page = createSequencePage(TYPES, singlePageRowCount * 3); + List splits = splitPage(page, page.getSizeInBytes() / 3); + System.out.println("Updated"); + WorkProcessor mergePages = mergePages(TYPES, page.getSizeInBytes() / 3 + 1, + page.getPositionCount() / 3 + 1, pagesSource(splits), newSimpleAggregatedMemoryContext()); + + validateResult(mergePages, actualPage -> assertPageEquals(TYPES, actualPage, page)); + assertFinishes(mergePages); + } + + @Test + public void testPageToVector() + { + Page page = createSequencePage(TYPES, 10); + Page page1 = createSequencePage(TYPES, 20); + Page page2 = createSequencePage(TYPES, 30); + + OmniMergePages.OmniMergePagesTransformation omni = new OmniMergePages.OmniMergePagesTransformation(TYPES, 5000, + 1024, 1024 * 1024, newSimpleAggregatedMemoryContext() + .newLocalMemoryContext(OmniMergePages.OmniMergePagesTransformation.class.getSimpleName())); + omni.appendPage(page); + omni.appendPage(page1); + omni.appendPage(page2); + Page actualPage = omni.flush(); + assertEquals(actualPage.getPositionCount(), 60); + } + + private static WorkProcessor pagesSource(Page... 
pages) + { + return WorkProcessor.fromIterable(ImmutableList.copyOf(pages)); + } + + private static WorkProcessor pagesSource(List pages) + { + return WorkProcessor.fromIterable(ImmutableList.copyOf(pages)); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniExpressionUtil.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniExpressionUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..79052cb6c4e47d07530e1e0df0637c5a8db0f2fb --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniExpressionUtil.java @@ -0,0 +1,417 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import io.airlift.slice.Slice; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.BuiltInFunctionHandle; +import io.prestosql.spi.function.OperatorType; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.relation.CallExpression; +import io.prestosql.spi.relation.ConstantExpression; +import io.prestosql.spi.relation.InputReferenceExpression; +import io.prestosql.spi.relation.RowExpression; +import io.prestosql.spi.relation.SpecialForm; +import io.prestosql.spi.relation.VariableReferenceExpression; +import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.tree.BetweenPredicate; +import io.prestosql.sql.tree.Cast; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.DecimalLiteral; +import io.prestosql.sql.tree.DoubleLiteral; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.LikePredicate; +import io.prestosql.sql.tree.LogicalBinaryExpression; +import io.prestosql.sql.tree.SearchedCaseExpression; +import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SymbolReference; +import io.prestosql.sql.tree.WhenClause; +import io.prestosql.testing.assertions.Assert; +import nova.hetu.olk.operator.filterandproject.OmniRowExpressionUtil; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.prestosql.spi.function.FunctionKind.SCALAR; +import static io.prestosql.spi.function.Signature.internalOperator; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static 
io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.sql.relational.Expressions.constant; +import static io.prestosql.sql.relational.Expressions.field; + +/** + * The OmniExpressionUtil tests + */ +@Test(singleThreaded = true) +public class TestOmniExpressionUtil +{ + private final InputReferenceExpression shipDate = field(0, BIGINT); + private final ConstantExpression condition1 = constant(10000, BIGINT); + private final ConstantExpression condition2 = constant(10471, BIGINT); + private final InputReferenceExpression extendedPrice = field(1, BIGINT); + private final InputReferenceExpression extendedDecimalPrice = field(1, createDecimalType()); + private final InputReferenceExpression discount = field(2, BIGINT); + private final ConstantExpression bool1 = constant(true, BOOLEAN); + private final ConstantExpression bool2 = constant(false, BOOLEAN); + private final String string1 = "%hello%world%"; + private final Slice slice1 = utf8Slice("%hello%world%"); + + private final List cmpOps = Arrays.asList(OperatorType.LESS_THAN_OR_EQUAL, OperatorType.LESS_THAN, + OperatorType.GREATER_THAN, OperatorType.GREATER_THAN_OR_EQUAL, OperatorType.EQUAL, OperatorType.NOT_EQUAL); + private final List arithOps = Arrays.asList(OperatorType.ADD, OperatorType.SUBTRACT, + OperatorType.MULTIPLY, OperatorType.DIVIDE, OperatorType.MODULUS); + private final OperatorType unaryOp = OperatorType.NEGATION; + + private final Signature likeSignature = new Signature(QualifiedObjectName.valueOfDefaultFunction("LIKE"), SCALAR, + parseTypeSignature(StandardTypes.BOOLEAN)); + private final List whenThenList = new ArrayList( + Arrays.asList( + new WhenClause( + new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new 
SymbolReference("avg"), new DoubleLiteral("0.0")), + new Cast(new DecimalLiteral("10.00"), "decimal(12,2)")), + new WhenClause(new LikePredicate(new SymbolReference("testColumn"), new StringLiteral("%a_%"), + Optional.empty()), new Cast(new DecimalLiteral("15.00"), "decimal(12,2)")))); + + @DataProvider(name = "binaryComparisonExpression") + private Object[][] prepareBinCmpTests() + { + Object[][] testCase = new Object[cmpOps.size()][]; + for (int i = 0; i < cmpOps.size(); i++) { + testCase[i] = new Object[]{ + call(internalOperator(cmpOps.get(i), BOOLEAN.getTypeSignature(), INTEGER.getTypeSignature(), + INTEGER.getTypeSignature()), BOOLEAN, shipDate, condition1), + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"" + cmpOps.get(i) + + "\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + + "2,\"colVal\":0},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false," + + "\"value\":10000}}"}; + } + return testCase; + } + + @DataProvider(name = "binaryArithmeticExpression") + private Object[][] prepareBinArithTests() + { + Object[][] testCase = new Object[arithOps.size()][]; + for (int i = 0; i < arithOps.size(); i++) { + testCase[i] = new Object[]{ + call(internalOperator(arithOps.get(i), BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), + BIGINT.getTypeSignature()), BIGINT, extendedPrice, discount), + "{\"exprType\":\"BINARY\",\"returnType\":2,\"operator\":\"" + arithOps.get(i) + + "\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + + "2,\"colVal\":1},\"right\":{\"exprType\":\"FIELD_REFERENCE\"," + + "\"dataType\":2,\"colVal\":2}}"}; + } + return testCase; + } + + @DataProvider(name = "binaryDecimalArithmeticExpression") + private Object[][] prepareBinDecimalArithTests() + { + Object[][] testCase = new Object[(arithOps.size() - 1) * 2][]; + int decimalTypesNum = 2; + for (int k = 0; k < decimalTypesNum; k++) { + int decimalPrecision = 16 * (k + 1); + DecimalType decimalType = createDecimalType(decimalPrecision, 2); + 
InputReferenceExpression decimalPrice = field(1, decimalType); + InputReferenceExpression decimalDiscount = field(2, decimalType); + int dataType = 6; + if (k == 1) { + dataType = 7; + } + for (int i = 0; i < arithOps.size() - 1; i++) { + testCase[((arithOps.size() - 1) * k) + i] = new Object[]{ + call(internalOperator(arithOps.get(i), decimalType.getTypeSignature(), + decimalType.getTypeSignature(), decimalType.getTypeSignature()), decimalType, + decimalPrice, decimalDiscount), + "{\"exprType\":\"BINARY\",\"returnType\":" + dataType + ",\"operator\":\"" + arithOps.get(i) + + "\",\"precision\":" + decimalPrecision + ",\"scale\":2," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + dataType + + ",\"colVal\":1," + "\"precision\":" + decimalPrecision + + ",\"scale\":2},\"right\":{\"exprType\":\"FIELD_REFERENCE\"," + "\"dataType\":" + + dataType + ",\"colVal\":2,\"precision\":" + decimalPrecision + ",\"scale\":2}}"}; + } + } + return testCase; + } + + @DataProvider(name = "constantExpression") + private Object[][] prepareConstantExpTests() + { + ByteBuffer stringBuffer = ByteBuffer.wrap("stringVal".getBytes()); + Slice varcharSlice = wrappedBuffer(stringBuffer); + + return new Object[][]{ + {constant(true, BOOLEAN), "{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":true}"}, + {constant(12345678910L, BIGINT), + "{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":12345678910}"}, + {constant(1, INTEGER), "{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}"}, + {constant(2.0, DOUBLE), "{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":2.0}"}, + {constant(4, createDecimalType(19)), + "{\"exprType\":\"LITERAL\",\"dataType\":7,\"isNull\":false,\"value\":\"4\",\"precision\":19,\"scale\":0}"}, + {constant(5, createDecimalType(37)), + "{\"exprType\":\"LITERAL\",\"dataType\":7,\"isNull\":false,\"value\":\"5\",\"precision\":37,\"scale\":0}"}, + {constant(varcharSlice, VARCHAR), + 
"{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"stringVal\",\"width\":1048576}"}, + // Need support of UNKNOWN presto type in DataType + // {constant("UNKNOWN", UNKNOWN), + // "{\"exprType\":\"LITERAL\",\"dataType\":0,\"isNull\":false,\"value\":UNKNOWN}"}, + {constant(null, BIGINT), "{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":true}"}}; + } + + @DataProvider(name = "sqlRegexConversion") + private Object[][] prepareSqlRegexTest() + { + return new Object[][]{{"%hello%world%", utf8Slice("^.*hello.*world.*$")}, {"a_%", utf8Slice("^a..*$")}, + {"5*5", utf8Slice("^5\\*5$")}, {"%$500.00^2.0", utf8Slice("^.*\\$500\\.00\\^2\\.0$")}}; + } + + @DataProvider(name = "omniFilterConversion") + private Object[][] prepareOmniFilterTest() + { + return new Object[][]{{(Expression) new LogicalBinaryExpression(LogicalBinaryExpression.Operator.AND, + new LikePredicate(new SymbolReference("testColumn"), new StringLiteral("%a_%"), Optional.empty()), + new LikePredicate(new SymbolReference("testColumn"), new StringLiteral("%p_%"), Optional.empty())), + (RowExpression) new SpecialForm(SpecialForm.Form.AND, BOOLEAN, + call(likeSignature, BOOLEAN, field(1, VARCHAR), constant(utf8Slice("%a_%"), VARCHAR)), + call(likeSignature, BOOLEAN, field(2, VARCHAR), constant(utf8Slice("%p_%"), VARCHAR))), + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*a..*$\",\"width\":1048576}]},\"right\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*p..*$\",\"width\":1048576}]}}"}, + {(Expression) new 
LogicalBinaryExpression(LogicalBinaryExpression.Operator.AND, + new BetweenPredicate(new SymbolReference("desc"), + new Cast(new DecimalLiteral("76.00"), "decimal(4,2)"), + new Cast(new DecimalLiteral("106.00"), "decimal(5,2)")), + new LikePredicate(new SymbolReference("testColumn"), + new Cast(new StringLiteral("%a%"), "varchar"), Optional.empty())), + (RowExpression) new SpecialForm(SpecialForm.Form.AND, BOOLEAN, + new SpecialForm(SpecialForm.Form.BETWEEN, BOOLEAN, field(3, createDecimalType(7, 2)), + constant(7600, createDecimalType(4, 0)), + constant(10600, createDecimalType(5, 0))), + call(likeSignature, BOOLEAN, field(2, VARCHAR), constant(utf8Slice("%p_%"), VARCHAR))), + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":{\"exprType\":\"BETWEEN\",\"returnType\":4,\"value\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":3,\"precision\":7,\"scale\":2},\"lower_bound\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":7600,\"precision\":4,\"scale\":0},\"upper_bound\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":10600,\"precision\":5,\"scale\":0}},\"right\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*a.*$\",\"width\":1048576}]}}"}, + {(Expression) new LogicalBinaryExpression(LogicalBinaryExpression.Operator.AND, + new LikePredicate(new SymbolReference("testColumn"), + new Cast(new StringLiteral("%p%"), "varchar"), Optional.empty()), + new LikePredicate(new SymbolReference("testColumn"), + new Cast(new StringLiteral("%a%"), "varchar"), Optional.empty())), + (RowExpression) new SpecialForm(SpecialForm.Form.AND, BOOLEAN, + call(likeSignature, BOOLEAN, field(2, VARCHAR), constant(utf8Slice("%p%"), VARCHAR)), + call(likeSignature, BOOLEAN, field(2, VARCHAR), constant(utf8Slice("%a%"), VARCHAR))), + 
"{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*p.*$\",\"width\":1048576}]},\"right\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*a.*$\",\"width\":1048576}]}}"}, + {(Expression) new SearchedCaseExpression(whenThenList, + Optional.of(new Cast(new DecimalLiteral("0.05"), "decimal(12,2)"))), + (RowExpression) new SpecialForm(SpecialForm.Form.IF, createDecimalType(12, 2), + call(new Signature(QualifiedObjectName.valueOfDefaultFunction("$operator$GREATER_THAN"), + SCALAR, parseTypeSignature(StandardTypes.BOOLEAN)), BOOLEAN, field(2, DOUBLE), + constant(0.0, DOUBLE)), + constant(1000, createDecimalType(12, 2)), + new SpecialForm(SpecialForm.Form.IF, createDecimalType(12, 2), + call(likeSignature, BOOLEAN, field(1, VARCHAR), + constant(utf8Slice("%a_%"), VARCHAR)), + constant(1500, createDecimalType(12, 2)), + constant(5, createDecimalType(12, 2)))), + 
"{\"exprType\":\"IF\",\"returnType\":6,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":2},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":0.0}},\"if_true\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":1000,\"precision\":12,\"scale\":2},\"if_false\":{\"exprType\":\"IF\",\"returnType\":6,\"condition\":{\"exprType\":\"FUNCTION\",\"returnType\":4,\"function_name\":\"LIKE\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":1048576},{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"^.*a..*$\",\"width\":1048576}]},\"if_true\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":1500,\"precision\":12,\"scale\":2},\"if_false\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":5,\"precision\":12,\"scale\":2}}}"}}; + } + + /** + * Test ConstantExpression + * + * @param literal ConstantExpression + * @param expected String + */ + @Test(dataProvider = "constantExpression") + public void testConstantExpression(ConstantExpression literal, String expected) + { + String parseRes = OmniRowExpressionUtil.expressionStringify(literal, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, literal.getType().getDisplayName() + " parsed result doesn't match"); + } + + /** + * Test InputReferenceExpression + */ + @Test + public void testInputReferenceExpression() + { + String referenceExpected = "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}"; + String parseRes = OmniRowExpressionUtil.expressionStringify(extendedPrice, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, referenceExpected, "InputReference parsed result doesn't match"); + } + + /** + * Test DecimalInputReferenceExpression + */ + @Test + public void testDecimnalInputReferenceExpression() + { + String referenceExpected = 
"{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"colVal\":1," + + "\"precision\":38,\"scale\":0}"; + String parseRes = OmniRowExpressionUtil.expressionStringify(extendedDecimalPrice, + OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, referenceExpected, "InputReference parsed result doesn't match"); + } + + /** + * Test VariableReferenceExpression + */ + @Test + public void testVariableReferenceExpression() + { + VariableReferenceExpression variableReference = new VariableReferenceExpression("testVariable", INTEGER); + String variableRefExpected = "{\"exprType\":\"VARIABLE_REFERENCE\",\"dataType\":1,\"varName\":\"testVariable\"}"; + String parseRes = OmniRowExpressionUtil.expressionStringify(variableReference, + OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, variableRefExpected, "variableReference parsed result doesn't match"); + } + + /** + * Test LikeExpressionConversion + */ + @Test(dataProvider = "sqlRegexConversion") + public void testLikeExpressionConversion(String query, Slice expected) + { + CallExpression testExpr = (CallExpression) call( + new Signature(QualifiedObjectName.valueOfDefaultFunction("LIKE"), SCALAR, + parseTypeSignature(StandardTypes.BOOLEAN)), + BOOLEAN, field(1, VARCHAR), constant(utf8Slice(query), VARCHAR)); + Optional testTranslatedExpr = Optional.of(testExpr); + Optional testLikeExpr = OmniRowExpressionUtil.generateLikeExpr(query, testTranslatedExpr); + Assert.assertEquals( + ((ConstantExpression) ((CallExpression) testLikeExpr.get()).getArguments().get(1)).getValue(), expected, + "SQL regex conversion result doesn't match"); + } + + /** + * Test OmniFilterExpressionConversion + */ + @Test(dataProvider = "omniFilterConversion") + public void testOmniFilterConversion(Expression staticExpr, RowExpression translatedExpr, String expected) + { + Optional parseExpr = OmniRowExpressionUtil.generateOmniExpr(staticExpr, translatedExpr); + String parseRes = 
OmniRowExpressionUtil.expressionStringify(parseExpr.get(), OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, "OmniFilterExpression parsed result does not match"); + } + + /** + * Test binary comparison CallExpression + * + * @param call CallExpression + * @param expected String + */ + @Test(dataProvider = "binaryComparisonExpression") + public void testCmpBinOps(CallExpression call, String expected) + { + String parseRes = OmniRowExpressionUtil.expressionStringify(call, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, call.getDisplayName() + " parsed result doesn't match"); + } + + /** + * Test binary arithmetic CallExpression + * + * @param call CallExpression + * @param expected String + */ + @Test(dataProvider = "binaryArithmeticExpression") + public void testArithBinOps(CallExpression call, String expected) + { + String parseRes = OmniRowExpressionUtil.expressionStringify(call, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, call.getDisplayName() + " parsed result doesn't match"); + } + + /** + * Test binary arithmetic CallExpression + * + * @param call CallExpression + * @param expected String + */ + @Test(dataProvider = "binaryDecimalArithmeticExpression") + public void testArithDecBinOps(CallExpression call, String expected) + { + String parseRes = OmniRowExpressionUtil.expressionStringify(call, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, call.getDisplayName() + " parsed result doesn't match"); + } + + /** + * Test unary CallExpression + */ + @Test + public void testUnaryOps() + { + CallExpression testNegation = (CallExpression) call( + internalOperator(unaryOp, BOOLEAN.getTypeSignature(), BOOLEAN.getTypeSignature()), BOOLEAN, + constant(false, BOOLEAN)); + String unaryOpExpected = "{\"exprType\":\"UNARY\",\"returnType\":4,\"operator\":\"NEGATION\",\"expr\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":false}}"; + 
String parseRes = OmniRowExpressionUtil.expressionStringify(testNegation, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, unaryOpExpected, "NEGATION parsed result doesn't match"); + } + + @DataProvider(name = "specialForm") + private Object[][] specialForm() + { + return new Object[][]{{new SpecialForm(SpecialForm.Form.AND, BOOLEAN, bool1, bool2), + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":true},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":false}}"}, + {new SpecialForm(SpecialForm.Form.OR, BOOLEAN, bool1, bool2), + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\",\"left\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":true},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":false}}"}, + {new SpecialForm(SpecialForm.Form.BETWEEN, BOOLEAN, shipDate, condition1, condition2), + "{\"exprType\":\"BETWEEN\",\"returnType\":4,\"value\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0},\"lower_bound\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10000},\"upper_bound\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10471}}"}, + {new SpecialForm(SpecialForm.Form.IF, BIGINT, bool1, condition1, condition2), + "{\"exprType\":\"IF\",\"returnType\":2,\"condition\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":false,\"value\":true},\"if_true\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10000},\"if_false\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10471}}"}, + {new SpecialForm(SpecialForm.Form.COALESCE, BIGINT, condition1, condition2), + "{\"exprType\":\"COALESCE\",\"returnType\":2,\"value1\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10000},\"value2\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":10471}}"}}; + } + 
+ /** + * Test SpecialForm + * + * @param specialForm SpecialForm + * @param expected String + */ + @Test(dataProvider = "specialForm") + public void testSpecialForm(SpecialForm specialForm, String expected) + { + String parseRes = OmniRowExpressionUtil.expressionStringify(specialForm, OmniRowExpressionUtil.Format.JSON); + Assert.assertEquals(parseRes, expected, specialForm.getForm().toString() + " parsed result doesn't match"); + } + + /** + * Test Lambda Expression + */ + @Test + public void testLambdaDefinitionExpression() + { + // TODO: Add tests when implemented + } + + private static RowExpression call(Signature signature, Type type, RowExpression... arguments) + { + BuiltInFunctionHandle functionHandle = new BuiltInFunctionHandle(signature); + CallExpression expression = new CallExpression(signature.getName().getObjectName(), functionHandle, type, + Arrays.asList(arguments)); + return expression; + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniMergingPageOutput.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniMergingPageOutput.java new file mode 100644 index 0000000000000000000000000000000000000000..9caca035f88da5c41b19c774ef9c6665fd630047 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOmniMergingPageOutput.java @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.Page; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.operator.filterandproject.OmniMergingPageOutput; +import nova.hetu.olk.tool.BlockUtils; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.google.common.collect.Iterators.transform; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.operator.PageAssertions.assertPageEquals; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestOmniMergingPageOutput +{ + @Test + public void testRowBlock() + { + List types = ImmutableList.of(BIGINT, DOUBLE); + List fields = new ArrayList<>(); + for (int i = 0; i < types.size(); i++) { + fields.add(new RowType.Field(Optional.of(i + ""), types.get(i))); + } + RowType rowType = RowType.from(fields); + + OmniMergingPageOutput operator = createOmniMergingPageOutput(ImmutableList.of(rowType), 1024, 10); + assertFalse(operator.isFinished()); + List values1 = new ArrayList<>(); + values1.add(1); + values1.add(1.1); + List values11 = new ArrayList<>(); + values11.add(2); + values11.add(2.2); + + List values2 = new ArrayList<>(); + values2.add(3); + values2.add(3.3); + List values22 = new ArrayList<>(); + values22.add(4); + values22.add(4.4); + List input = 
rowPagesBuilder(rowType).row(values1).row(values11).pageBreak().row(values2).row(values22) + .build(); + // transfer on-heap page to off-heap + List offHeapPages = input.stream().map(var -> + OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, var, + Arrays.stream(var.getBlocks()).map(val -> rowType).collect(Collectors.toList()))) + .collect(Collectors.toList()); + operator.addInput(createPagesIterator(offHeapPages)); + + assertNull(operator.getOutput()); + assertFalse(operator.isFinished()); + operator.finish(); + assertFalse(operator.isFinished()); + + Page expected = rowPagesBuilder(rowType).row(values1).row(values11).row(values2).row(values22).build().get(0); + Page result = operator.getOutput(); + assertPageEquals(ImmutableList.of(rowType), result, expected); + assertTrue(operator.isFinished()); + BlockUtils.freePage(result); + } + + private OmniMergingPageOutput createOmniMergingPageOutput(List sourceTypes, long minPageSizeInBytes, + int minRowCount) + { + return new OmniMergingPageOutput(sourceTypes, minPageSizeInBytes, minRowCount); + } + + private Iterator> createPagesIterator(List pages) + { + return transform(pages.iterator(), Optional::of); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOrderByOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOrderByOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..376d94b3af4134b6f220e967989ad24db311fdef --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestOrderByOmniOperator.java @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.operator.DriverContext; +import io.prestosql.spi.Page; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.TestingTaskContext; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.units.DataSize.succinctBytes; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.OperatorAssertion.toMaterializedResult; +import static io.prestosql.operator.OperatorAssertion.toPages; +import static io.prestosql.spi.block.SortOrder.ASC_NULLS_LAST; +import static io.prestosql.spi.block.SortOrder.DESC_NULLS_LAST; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import 
static nova.hetu.olk.operator.OrderByOmniOperator.OrderByOmniOperatorFactory.createOrderByOmniOperatorFactory; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestOrderByOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test + public void testMultipleOutputPages() + { + // make operator produce multiple pages during finish phase + long start = System.nanoTime(); + int numberOfRows = 80_000; + List input = rowPagesBuilder(BIGINT, DOUBLE).addSequencePage(numberOfRows, 0, 0).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OrderByOmniOperator.OrderByOmniOperatorFactory operatorFactory = createOrderByOmniOperatorFactory(0, + new PlanNodeId("test"), ImmutableList.of(BIGINT, DOUBLE), ImmutableList.of(1), ImmutableList.of(0, 1), + ImmutableList.of(DESC_NULLS_LAST, DESC_NULLS_LAST)); + + DriverContext driverContext = createDriverContext(0); + MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), DOUBLE); + for (int i = 0; i < numberOfRows; ++i) { + expectedBuilder.row((double) numberOfRows - i - 1); + } + MaterializedResult expected = expectedBuilder.build(); + + List pages = toPages(operatorFactory, driverContext, offHeapPages, false); + + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), pages); + assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows()); + long end = System.nanoTime(); + double cost = ((double) (end - start)) / 1000000.0; + 
System.out.println("testMultipleOutputPages elapsed time : " + cost + "ms"); + } + + @Test + public void testSingleFieldKey() + { + long start = System.nanoTime(); + List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(-1L, -0.1) + .row(4L, 0.4).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + OrderByOmniOperator.OrderByOmniOperatorFactory operatorFactory = createOrderByOmniOperatorFactory(0, + new PlanNodeId("test"), ImmutableList.of(BIGINT, DOUBLE), ImmutableList.of(1), ImmutableList.of(0, 1), + ImmutableList.of(ASC_NULLS_LAST, ASC_NULLS_LAST)); + DriverContext driverContext = createDriverContext(0); + MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE).row(-0.1).row(0.1).row(0.2) + .row(0.4).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + + long end = System.nanoTime(); + double cost = ((double) (end - start)) / 1000000.0; + System.out.println("testSingleFieldKey elapsed time : " + cost + "ms"); + } + + private DriverContext createDriverContext(long memoryLimit) + { + return TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setMemoryPoolSize(succinctBytes(memoryLimit)).build().addPipelineContext(0, true, true, false) + .addDriverContext(); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestScanFilterAndProjectOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestScanFilterAndProjectOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..9100491e09013e12968faab80dfe3de81d299525 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestScanFilterAndProjectOmniOperator.java @@ -0,0 +1,933 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; +import io.hetu.core.statestore.hazelcast.HazelcastStateStoreBootstrapper; +import io.hetu.core.statestore.hazelcast.HazelcastStateStoreFactory; +import io.prestosql.SequencePageBuilder; +import io.prestosql.Session; +import io.prestosql.dynamicfilter.DynamicFilterCacheManager; +import io.prestosql.execution.Lifespan; +import io.prestosql.metadata.BoundVariables; +import io.prestosql.metadata.FunctionAndTypeManager; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.Split; +import io.prestosql.metadata.SqlScalarFunction; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.DummySpillerFactory; +import io.prestosql.operator.Operator; +import io.prestosql.operator.SourceOperator; +import io.prestosql.operator.WorkProcessorSourceOperatorAdapter; +import io.prestosql.operator.index.PageRecordSet; +import io.prestosql.operator.project.CursorProcessor; +import io.prestosql.operator.project.PageProcessor; +import io.prestosql.operator.scalar.AbstractTestFunctions; +import io.prestosql.operator.scalar.ApplyFunction; +import io.prestosql.seedstore.SeedStoreManager; +import io.prestosql.spi.Page; +import io.prestosql.spi.connector.CatalogName; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.ConnectorPageSource; +import io.prestosql.spi.connector.FixedPageSource; +import 
io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.connector.RecordPageSource; +import io.prestosql.spi.connector.TestingColumnHandle; +import io.prestosql.spi.function.BuiltInFunctionHandle; +import io.prestosql.spi.function.BuiltInScalarFunctionImplementation; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.operator.ReuseExchangeOperator; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.plan.Symbol; +import io.prestosql.spi.plan.TableScanNode; +import io.prestosql.spi.relation.RowExpression; +import io.prestosql.spi.seedstore.Seed; +import io.prestosql.spi.seedstore.SeedStore; +import io.prestosql.spi.seedstore.SeedStoreSubType; +import io.prestosql.spi.statestore.StateStore; +import io.prestosql.spi.statestore.StateStoreBootstrapper; +import io.prestosql.spi.statestore.StateStoreFactory; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.TimeZoneKey; +import io.prestosql.sql.gen.PageFunctionCompiler; +import io.prestosql.statestore.StateStoreProvider; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.TestingSession; +import io.prestosql.testing.TestingSplit; +import nova.hetu.olk.operator.filterandproject.OmniExpressionCompiler; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.LongUnaryOperator; +import 
java.util.function.Supplier; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.airlift.units.DataSize.Unit.KILOBYTE; +import static io.hetu.core.statestore.hazelcast.HazelcastConstants.DISCOVERY_PORT_CONFIG_NAME; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.SystemSessionProperties.ENABLE_CROSS_REGION_DYNAMIC_FILTER; +import static io.prestosql.block.BlockAssertions.toValues; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.operator.OperatorAssertion.toMaterializedResult; +import static io.prestosql.operator.PageAssertions.assertPageEquals; +import static io.prestosql.operator.project.PageProcessor.MAX_BATCH_SIZE; +import static io.prestosql.spi.function.BuiltInScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; +import static io.prestosql.spi.function.BuiltInScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; +import static io.prestosql.spi.function.OperatorType.EQUAL; +import static io.prestosql.spi.function.Signature.internalScalarFunction; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.spi.util.DateTimeZoneIndex.getDateTimeZone; +import static io.prestosql.spi.util.Reflection.methodHandle; +import static io.prestosql.sql.relational.Expressions.call; +import static io.prestosql.sql.relational.Expressions.constant; +import static io.prestosql.sql.relational.Expressions.field; +import static io.prestosql.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static 
io.prestosql.testing.assertions.Assert.assertEquals; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.concurrent.TimeUnit.SECONDS; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestScanFilterAndProjectOmniOperator + extends AbstractTestFunctions +{ + private final Metadata metadata = createTestMetadataManager(); + private final OmniExpressionCompiler omniExpressionCompiler = new OmniExpressionCompiler(metadata, + new PageFunctionCompiler(metadata, 0)); + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private StateStoreProvider stateStoreProvider; + private static final TimeZoneKey TIME_ZONE_KEY = TestingSession.DEFAULT_TIME_ZONE_KEY; + private static final DateTimeZone DATE_TIME_ZONE = getDateTimeZone(TIME_ZONE_KEY); + + @BeforeTest + private void prepareConfigFiles() + throws Exception + { + Set seeds = new HashSet<>(); + SeedStore mockSeedStore = mock(SeedStore.class); + Seed mockSeed = mock(Seed.class); + seeds.add(mockSeed); + + SeedStoreManager mockSeedStoreManager = mock(SeedStoreManager.class); + when(mockSeedStoreManager.getSeedStore(SeedStoreSubType.HAZELCAST)).thenReturn(mockSeedStore); + + when(mockSeed.getLocation()).thenReturn("127.0.0.1:7991"); + when(mockSeedStore.get()).thenReturn(seeds); + + StateStoreFactory factory = new HazelcastStateStoreFactory(); + stateStoreProvider = new LocalStateStoreProviderTest(mockSeedStoreManager); + stateStoreProvider.addStateStoreFactory(factory); + createStateStoreCluster("7991"); + stateStoreProvider.loadStateStore(); + } + + @AfterTest + private void 
cleanUp() + { + } + + private StateStore createStateStoreCluster(String port) + { + Map config = new HashMap<>(); + config.put("hazelcast.discovery.mode", "tcp-ip"); + config.put("state-store.cluster", "test-cluster"); + config.put(DISCOVERY_PORT_CONFIG_NAME, port); + + StateStoreBootstrapper bootstrapper = new HazelcastStateStoreBootstrapper(); + return bootstrapper.bootstrap(ImmutableSet.of("127.0.0.1:" + port), config); + } + + public TestScanFilterAndProjectOmniOperator() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @Test + public void testPageSource() + { + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + + List projections = ImmutableList.of(field(0, VARCHAR)); + Supplier cursorProcessor = omniExpressionCompiler.compileCursorProcessor(Optional.empty(), + projections, "key"); + Supplier pageProcessor = omniExpressionCompiler.compilePageProcessor(Optional.empty(), + projections); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + 0, new PlanNodeId("test"), new PlanNodeId("0"), + (session, split, table, columns, dynamicFilter) -> new FixedPageSource(ImmutableList.of(input)), + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(VARCHAR), + new DataSize(0, BYTE), 0, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), false, + Optional.empty(), 0, 0, ImmutableList.of(VARCHAR)); + + SourceOperator operator = factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + 
MaterializedResult expected = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + ImmutableList.of(input)); + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operator)); + + assertEquals(actual.getRowCount(), expected.getRowCount()); + assertEquals(actual, expected); + } + + @Test + public void testPageSourceMergeOutput() + { + List input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + rowPagesBuilder(BIGINT).addSequencePage(100, 0).addSequencePage(100, 0).addSequencePage(100, 0) + .addSequencePage(100, 0).build()); + + RowExpression filter = call(EQUAL.getFunctionName().toString(), + new BuiltInFunctionHandle(Signature.internalOperator(EQUAL, BOOLEAN.getTypeSignature(), ImmutableList.of(BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BOOLEAN, + field(0, BIGINT), + constant(10L, BIGINT)); + List projections = ImmutableList.of(field(0, BIGINT)); + Supplier cursorProcessor = omniExpressionCompiler.compileCursorProcessor(Optional.of(filter), + projections, "key"); + Supplier pageProcessor = omniExpressionCompiler.compilePageProcessor(Optional.of(filter), + projections); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + 0, new PlanNodeId("test"), new PlanNodeId("0"), + (session, split, table, columns, dynamicFilter) -> new FixedPageSource(input), cursorProcessor, + pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(BIGINT), + new DataSize(64, KILOBYTE), 2, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), + false, Optional.empty(), 0, 0, ImmutableList.of(BIGINT)); + + SourceOperator operator = factory.createOperator(newDriverContext()); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + List actual = 
toPages(operator); + assertEquals(actual.size(), 1); + + List expected = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + rowPagesBuilder(BIGINT).row(10L).row(10L).row(10L).row(10L).build()); + + assertPageEquals(ImmutableList.of(BIGINT), actual.get(0), expected.get(0)); + + try { + operator.close(); + } + catch (Exception e) { + e.printStackTrace(); + } + } + + @Test + public void testRecordCursorSource() + { + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + + List projections = ImmutableList.of(field(0, VARCHAR)); + Supplier cursorProcessor = omniExpressionCompiler.compileCursorProcessor(Optional.empty(), projections, "key"); + Supplier pageProcessor = omniExpressionCompiler.compilePageProcessor(Optional.empty(), projections); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + 0, new PlanNodeId("test"), new PlanNodeId("0"), + (session, split, table, columns, + dynamicFilter) -> new RecordPageSource(new PageRecordSet(ImmutableList.of(VARCHAR), input)), + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(VARCHAR), + new DataSize(0, BYTE), 0, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), false, + Optional.empty(), 0, 0, ImmutableList.of(VARCHAR)); + + SourceOperator operator = factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + MaterializedResult expected = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + ImmutableList.of(input)); + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + 
toPages(operator)); + + assertEquals(actual.getRowCount(), expected.getRowCount()); + assertEquals(actual, expected); + } + + @Test + public void testRecordPageSource() + { + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + + List projections = ImmutableList.of(field(0, VARCHAR)); + Supplier cursorProcessor = omniExpressionCompiler.compileCursorProcessor(Optional.empty(), + projections, "key"); + Supplier pageProcessor = omniExpressionCompiler.compilePageProcessor(Optional.empty(), + projections); + + Symbol symbol = new Symbol("test"); + Map assignments = new HashMap<>(); + assignments.put(symbol, new TestingColumnHandle("testingColumnHandle")); + TableScanNode tableScanNode = TableScanNode.newInstance(new PlanNodeId("tableScan"), TEST_TABLE_HANDLE, + ImmutableList.of(symbol), assignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, + new UUID(0, 0), 1, false); + Session localSession = Session.builder(session) + .setStartTime(new DateTime(2017, 3, 1, 14, 30, 0, 0, DATE_TIME_ZONE).getMillis()) + .setSystemProperty(ENABLE_CROSS_REGION_DYNAMIC_FILTER, "true").build(); + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + localSession, 0, new PlanNodeId("test"), tableScanNode, + (session, split, table, columns, + dynamicFilter) -> new RecordPageSource(new PageRecordSet(ImmutableList.of(VARCHAR), input)), + + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(VARCHAR), + stateStoreProvider, functionAssertions.getMetadata(), new DynamicFilterCacheManager(), + new DataSize(0, BYTE), 0, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), false, + Optional.empty(), 0, 0, ImmutableList.of(VARCHAR)); + + SourceOperator operator = 
factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + MaterializedResult expected = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + ImmutableList.of(input)); + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operator)); + + assertEquals(actual.getRowCount(), expected.getRowCount()); + assertEquals(actual, expected); + } + + @Test + public void testPageYield() + { + int totalRows = 1000; + Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(BIGINT), totalRows, 1)); + DriverContext driverContext = newDriverContext(); + + // 20 columns; each column is associated with a function that will force yield + // per projection + int totalColumns = 20; + ImmutableList.Builder functions = ImmutableList.builder(); + for (int i = 0; i < totalColumns; i++) { + functions.add(new GenericLongFunction("page_col" + i, value -> { + driverContext.getYieldSignal().forceYieldForTesting(); + return value; + })); + functions.add(ApplyFunction.APPLY_FUNCTION); + } + Metadata localMetadata = functionAssertions.getMetadata(); + localMetadata.getFunctionAndTypeManager().registerBuiltInFunctions(functions.build()); + + // match each column with a projection + OmniExpressionCompiler omniCompiler = new OmniExpressionCompiler(localMetadata, + new PageFunctionCompiler(localMetadata, 0)); + ImmutableList.Builder projections = ImmutableList.builder(); + for (int i = 0; i < totalColumns; i++) { + projections.add(call(QualifiedObjectName.valueOfDefaultFunction("generic_long_page_col" + i).toString(), + new BuiltInFunctionHandle(internalScalarFunction( + QualifiedObjectName.valueOfDefaultFunction("generic_long_page_col" + i), + BIGINT.getTypeSignature(), ImmutableList.of(BIGINT.getTypeSignature()))), + 
BIGINT, field(0, BIGINT))); + } + Supplier cursorProcessor = omniCompiler.compileCursorProcessor(Optional.empty(), projections.build(), "key"); + Supplier pageProcessor = omniCompiler.compilePageProcessor(Optional.empty(), projections.build(), MAX_BATCH_SIZE); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + 0, new PlanNodeId("test"), new PlanNodeId("0"), + (session, split, table, columns, dynamicFilter) -> new FixedPageSource(ImmutableList.of(input)), + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(BIGINT), + new DataSize(0, BYTE), 0, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), false, + Optional.empty(), 0, 0, ImmutableList.of(BIGINT)); + + SourceOperator operator = factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + // In the below loop we yield for every cell: 20 X 1000 times + // Currently we don't check for the yield signal in the generated projection + // loop, we only check for the yield signal + // in the PageProcessor.PositionsPageProcessorIterator::computeNext() method. + // Therefore, after 20 calls we will have + // exactly 20 blocks (one for each column) and the PageProcessor will be able to + // create a Page out of it. 
+ for (int i = 1; i <= totalRows * totalColumns; i++) { + driverContext.getYieldSignal().setWithDelay(SECONDS.toNanos(1000), driverContext.getYieldExecutor()); + Page page = operator.getOutput(); + if (i == totalColumns) { + assertNotNull(page); + assertEquals(page.getPositionCount(), totalRows); + assertEquals(page.getChannelCount(), totalColumns); + for (int j = 0; j < totalColumns; j++) { + assertEquals(toValues(BIGINT, page.getBlock(j)), toValues(BIGINT, input.getBlock(0))); + } + } + else { + assertNull(page); + } + driverContext.getYieldSignal().reset(); + } + } + + @Test + public void testRecordCursorYield() + { + // create a generic long function that yields for projection on every row + // verify we will yield #row times totally + + // create a table with 15 rows + int length = 15; + Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(BIGINT), length, 0)); + DriverContext driverContext = newDriverContext(); + + // set up generic long function with a callback to force yield + Metadata localMetadata = functionAssertions.getMetadata(); + localMetadata.getFunctionAndTypeManager() + .registerBuiltInFunctions(ImmutableList.of(new GenericLongFunction("record_cursor", value -> { + driverContext.getYieldSignal().forceYieldForTesting(); + return value; + }))); + OmniExpressionCompiler omniCompiler = new OmniExpressionCompiler(localMetadata, + new PageFunctionCompiler(localMetadata, 0)); + + List projections = ImmutableList + .of(call(QualifiedObjectName.valueOfDefaultFunction("generic_long_record_cursor").toString(), + new BuiltInFunctionHandle(internalScalarFunction( + QualifiedObjectName.valueOfDefaultFunction("generic_long_record_cursor"), + BIGINT.getTypeSignature(), ImmutableList.of(BIGINT.getTypeSignature()))), + BIGINT, field(0, BIGINT))); + Supplier cursorProcessor = omniCompiler.compileCursorProcessor(Optional.empty(), projections, "key"); + Supplier pageProcessor = 
omniCompiler.compilePageProcessor(Optional.empty(), projections); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + 0, new PlanNodeId("test"), new PlanNodeId("0"), + (session, split, table, columns, + dynamicFilter) -> new RecordPageSource(new PageRecordSet(ImmutableList.of(BIGINT), input)), + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(BIGINT), + new DataSize(0, BYTE), 0, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), false, + Optional.empty(), 0, 0, ImmutableList.of(BIGINT)); + + SourceOperator operator = factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + + // start driver; get null value due to yield for the first 15 times + for (int i = 0; i < length; i++) { + driverContext.getYieldSignal().setWithDelay(SECONDS.toNanos(1000), driverContext.getYieldExecutor()); + assertNull(operator.getOutput()); + driverContext.getYieldSignal().reset(); + } + + // the 16th yield is not going to prevent the operator from producing a page + driverContext.getYieldSignal().setWithDelay(SECONDS.toNanos(1000), driverContext.getYieldExecutor()); + Page output = operator.getOutput(); + driverContext.getYieldSignal().reset(); + assertNotNull(output); + assertEquals(toValues(BIGINT, output.getBlock(0)), toValues(BIGINT, input.getBlock(0))); + } + + @Test + public void testReusePageSource() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + 
ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, false, 10, 1); + + producerPages = toPages(operatorProducer); + MaterializedResult producerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), ImmutableList.of(input)); + MaterializedResult producerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("0", uuid, 1, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, false, 10, 1); + MaterializedResult consumerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), producerPages); + MaterializedResult consumerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operatorConsumer)); + assertEquals(consumerActual.getRowCount(), consumerExpected.getRowCount()); + assertEquals(consumerActual, consumerExpected); + } + + @Test + public void testReuseExchangeSpill() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, true, 10, 1); + WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorProducer; + producerPages = toPages(operatorProducer); + + // check spilling is done + boolean notSpilled = workProcessorSourceOperatorAdapter.isNotSpilled(); + assertEquals(false, notSpilled); + MaterializedResult producerExpected = 
toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), ImmutableList.of(input)); + MaterializedResult producerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("0", uuid, 1, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 10, 1); + MaterializedResult consumerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), producerPages); + MaterializedResult consumerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operatorConsumer)); + assertEquals(consumerActual.getRowCount(), consumerExpected.getRowCount()); + assertEquals(consumerActual, consumerExpected); + } + + @Test + public void testReuseExchangeInMemorySpill() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, true, 75000, 1); + WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorProducer; + producerPages = toPages(operatorProducer); + + // check spilling is done + boolean notSpilled = workProcessorSourceOperatorAdapter.isNotSpilled(); + assertEquals(true, notSpilled); + MaterializedResult producerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), ImmutableList.of(input)); + MaterializedResult producerActual = 
toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("0", uuid, 1, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 75000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + boolean flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + MaterializedResult consumerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), producerPages); + MaterializedResult consumerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operatorConsumer)); + assertEquals(consumerActual.getRowCount(), consumerExpected.getRowCount()); + assertEquals(consumerActual, consumerExpected); + + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + } + + @Test + public void testReuseExchangeMultipleConsumer() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, true, 800000, 2); + WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorProducer; + producerPages = toPages(operatorProducer); + + // check 
spilling is done + boolean notSpilled = workProcessorSourceOperatorAdapter.isNotSpilled(); + assertEquals(true, notSpilled); + MaterializedResult producerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), ImmutableList.of(input)); + MaterializedResult producerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer 1 + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("1", uuid, 1, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 800000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + boolean flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + MaterializedResult consumerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), producerPages); + MaterializedResult consumerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operatorConsumer)); + assertEquals(consumerActual.getRowCount(), consumerExpected.getRowCount()); + assertEquals(consumerActual, consumerExpected); + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + + // Consumer 2 + SourceOperator operatorConsumer1 = createScanFilterAndProjectOmniOperator("2", uuid, 2, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 80000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer1; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + MaterializedResult consumerExpected1 = 
toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), producerPages); + MaterializedResult consumerActual1 = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + toPages(operatorConsumer1)); + assertEquals(consumerActual1.getRowCount(), consumerExpected1.getRowCount()); + assertEquals(consumerActual1, consumerExpected1); + + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer1; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + } + + @Test + public void testReuseExchangeInMemorySpillMultipleConsumer() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, true, 75000, 2); + WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorProducer; + producerPages = toPages(operatorProducer); + + // check spilling is done + boolean notSpilled = workProcessorSourceOperatorAdapter.isNotSpilled(); + assertEquals(true, notSpilled); + MaterializedResult producerExpected = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + ImmutableList.of(input)); + MaterializedResult producerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer 1 + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("1", uuid, 1, input, driverContext, + 
ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 75000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + boolean flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + List consumer1Pages = new ArrayList<>(); + // read pages 1st time from pageCaches + consumer1Pages.addAll(toPages(operatorConsumer, true)); + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + + // Consumer 2 + SourceOperator operatorConsumer2 = createScanFilterAndProjectOmniOperator("2", uuid, 2, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 75000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer2; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + List consumer2Pages = new ArrayList<>(); + // read pages 1st time from pageCaches + consumer2Pages.addAll(toPages(operatorConsumer2)); + assertEquals(producerPages.size(), consumer2Pages.size()); + + // read pages 2nd time after pages are moved form pagesToSpill to pageCaches + consumer1Pages.addAll(toPages(operatorConsumer)); + assertEquals(producerPages.size(), consumer1Pages.size()); + + // consumer1 should be finished after reading from pagesToSpill & pageCaches + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + + // consumer2 should be finished after reading from pagesToSpill & pageCaches + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer2; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + } + + 
@Test + public void testReuseExchangeSpillToDiskMultipleConsumer() + { + UUID uuid = UUID.randomUUID(); + final Page input = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, + SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0)); + DriverContext driverContext = newDriverContext(); + List producerPages; + + SourceOperator operatorProducer = createScanFilterAndProjectOmniOperator("0", uuid, 0, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_PRODUCER, true, 4000, 2); + WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorProducer; + producerPages = toPages(operatorProducer); + + // check spilling is done + boolean notSpilled = workProcessorSourceOperatorAdapter.isNotSpilled(); + assertEquals(false, notSpilled); + MaterializedResult producerExpected = toMaterializedResult(driverContext.getSession(), + ImmutableList.of(VARCHAR), ImmutableList.of(input)); + MaterializedResult producerActual = toMaterializedResult(driverContext.getSession(), ImmutableList.of(VARCHAR), + producerPages); + assertEquals(producerActual.getRowCount(), producerExpected.getRowCount()); + assertEquals(producerActual, producerExpected); + + // Consumer 1 + SourceOperator operatorConsumer = createScanFilterAndProjectOmniOperator("1", uuid, 1, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 4000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + boolean flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + List consumer1Pages = new ArrayList<>(); + // read pages 1st time from pageCaches + consumer1Pages.addAll(toPages(operatorConsumer, true)); + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + + // Consumer 2 + SourceOperator operatorConsumer2 = 
createScanFilterAndProjectOmniOperator("2", uuid, 2, input, driverContext, + ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_CONSUMER, true, 4000, 0); + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer2; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(false, flag); + List consumer2Pages = new ArrayList<>(); + // since its last consumer it will read pages pageCaches & pagesToSpill at a + // time + consumer2Pages.addAll(toPages(operatorConsumer2)); + assertEquals(producerPages.size(), consumer2Pages.size()); + + // consumer2 should be finished after reading from pagesToSpill & pageCaches + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer2; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + + // read pages 2nd time after pages are moved form pagesToSpill to pageCaches + consumer1Pages.addAll(toPages(operatorConsumer)); + assertEquals(producerPages.size(), consumer1Pages.size()); + + // consumer1 should be finished after reading from pagesToSpill & pageCaches + workProcessorSourceOperatorAdapter = (WorkProcessorSourceOperatorAdapter) operatorConsumer; + flag = getWorkProcessorSourceOperatorAdapterCheckFinished(workProcessorSourceOperatorAdapter); + assertEquals(true, flag); + } + + private boolean getWorkProcessorSourceOperatorAdapterCheckFinished(WorkProcessorSourceOperatorAdapter workProcessorSourceOperatorAdapter) + { + boolean returnValue = false; + try { + Method privateStringMethod = WorkProcessorSourceOperatorAdapter.class.getDeclaredMethod("checkFinished", null); + privateStringMethod.setAccessible(true); + returnValue = (boolean) privateStringMethod.invoke(workProcessorSourceOperatorAdapter, null); + } + catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + // the exception could be ignored + } + return 
returnValue; + } + + private SourceOperator createScanFilterAndProjectOmniOperator(String sourceId, UUID uuid, int operatorId, Page input, DriverContext driverContext, + ReuseExchangeOperator.STRATEGY strategy, boolean spillEnabled, Integer spillerThreshold, + Integer consumerTableScanNodeCount) + { + List projections = ImmutableList.of(field(0, VARCHAR)); + Supplier cursorProcessor = omniExpressionCompiler.compileCursorProcessor(Optional.empty(), + projections, "key"); + Supplier pageProcessor = omniExpressionCompiler.compilePageProcessor(Optional.empty(), + projections); + + ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory factory = new ScanFilterAndProjectOmniOperator.ScanFilterAndProjectOmniOperatorFactory( + operatorId, new PlanNodeId("test"), new PlanNodeId(sourceId), + (session, split, table, columns, + dynamicFilter) -> new FixedPageSource(ImmutableList.of(input)), + cursorProcessor, pageProcessor, TEST_TABLE_HANDLE, ImmutableList.of(), null, ImmutableList.of(VARCHAR), + new DataSize(0, BYTE), 0, strategy, uuid, spillEnabled, Optional.of(new DummySpillerFactory()), + spillerThreshold, consumerTableScanNodeCount, ImmutableList.of(VARCHAR)); + + SourceOperator operator = factory.createOperator(driverContext); + operator.addSplit(new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), Lifespan.taskWide())); + operator.noMoreSplits(); + return operator; + } + + private static List toPages(Operator operator) + { + return toPages(operator, false); + } + + private static List toPages(Operator operator, boolean retNotFinished) + { + ImmutableList.Builder outputPages = ImmutableList.builder(); + + // read output until input is needed or operator is finished + int nullPages = 0; + while (!operator.isFinished()) { + Page outputPage = operator.getOutput(); + if (outputPage == null) { + if (retNotFinished) { + return outputPages.build(); + } + // break infinite loop due to null pages + assertTrue(nullPages < 1_000_000, "Too many null 
pages; infinite loop?"); + nullPages++; + } + else { + outputPages.add(outputPage); + nullPages = 0; + } + } + + return outputPages.build(); + } + + private DriverContext newDriverContext() + { + return createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + } + + public static class SinglePagePageSource + implements ConnectorPageSource + { + private Page page; + + public SinglePagePageSource(Page page) + { + this.page = page; + } + + @Override + public void close() + { + page = null; + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public boolean isFinished() + { + return page == null; + } + + @Override + public Page getNextPage() + { + Page tmpPage = this.page; + this.page = null; + return tmpPage; + } + } + + static class GenericLongFunction + extends SqlScalarFunction + { + private final MethodHandle methodHandle = methodHandle(io.prestosql.operator.GenericLongFunction.class, + "apply", LongUnaryOperator.class, long.class); + + private final LongUnaryOperator longUnaryOperator; + + GenericLongFunction(String suffix, LongUnaryOperator longUnaryOperator) + { + super(internalScalarFunction( + QualifiedObjectName + .valueOfDefaultFunction("generic_long_" + requireNonNull(suffix, "suffix is null")), + parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT))); + this.longUnaryOperator = longUnaryOperator; + } + + @Override + public boolean isHidden() + { + return true; + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public String getDescription() + { + return "generic long function for test"; + } + + @Override + public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, + FunctionAndTypeManager 
functionAndTypeManager) + { + MethodHandle methodHandle = this.methodHandle.bindTo(longUnaryOperator); + return new BuiltInScalarFunctionImplementation(false, + ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), methodHandle); + } + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestTopNOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestTopNOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..60f0acfdf9371bcbefdd491b49e46522c90c114b --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestTopNOmniOperator.java @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; +import io.prestosql.ExceededMemoryLimitException; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.Operator; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.spi.Page; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.testing.MaterializedResult; +import nova.hetu.olk.block.DictionaryOmniBlock; +import nova.hetu.olk.block.IntArrayOmniBlock; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.spi.block.SortOrder.ASC_NULLS_LAST; +import static io.prestosql.spi.block.SortOrder.DESC_NULLS_LAST; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static io.prestosql.testing.TestingTaskContext.createTaskContext; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static 
org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.fail; + +@Test(singleThreaded = true) +public class TestTopNOmniOperator +{ + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + private DriverContext driverContext; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false).addDriverContext(); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test + public void testSingleFieldKey() + { + List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(-1L, -0.1) + .row(4L, 0.4).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build(); + // transfer on-heap page to off-heap + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE), 2, ImmutableList.of(0), ImmutableList.of(DESC_NULLS_LAST)); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(6L, 0.6) + .row(5L, 0.5).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = false) + public void testSingleFieldKeyDictionaryVec() + { + int[] ints = {2, 1, 4, 3, 2}; + IntArrayOmniBlock intArrayOmniBlock = new IntArrayOmniBlock(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, ints.length, + Optional.empty(), ints); + int[] ids = {0, 1, 2, 3, 4}; + DictionaryOmniBlock integerDictionaryOmniBlock = new DictionaryOmniBlock( + (Vec) intArrayOmniBlock.getValues(), ids); 
+ Page input = new Page(integerDictionaryOmniBlock); + + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE), 2, ImmutableList.of(0), ImmutableList.of(DESC_NULLS_LAST)); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(6L, 0.6) + .row(5L, 0.5).build(); + + assertOperatorEquals(operatorFactory, driverContext, Collections.singletonList(input), expected); + } + + @Test + public void testMultiFieldKey() + { + List input = rowPagesBuilder(INTEGER, BIGINT).row(0, 1L).row(1, 2L).pageBreak().row(5, 3L).row(0, 4L) + .pageBreak().row(3, 5L).row(3, 7L).row(4, 6L).build(); + + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT), 3, ImmutableList.of(0, 1), + ImmutableList.of(DESC_NULLS_LAST, DESC_NULLS_LAST)); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT).row(5, 3L) + .row(4, 6L).row(3, 7L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test(enabled = false) + public void testMultiFieldKeyWithVarChar() + { + List input = rowPagesBuilder(VARCHAR, BIGINT).row("a", 1L).row("b", 2L).pageBreak().row("f", 3L) + .row("a", 4L).pageBreak().row("d", 5L).row("d", 7L).row("e", 6L).build(); + + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(VARCHAR, BIGINT), 3, ImmutableList.of(0, 1), + ImmutableList.of(DESC_NULLS_LAST, DESC_NULLS_LAST)); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT).row("f", 3L) + .row("e", 6L).row("d", 7L).build(); + + assertOperatorEquals(operatorFactory, driverContext, input, expected); + } + + @Test + public void testReverseOrder() + { 
+ List input = rowPagesBuilder(BIGINT, DOUBLE).row(1L, 0.1).row(2L, 0.2).pageBreak().row(-1L, -0.1) + .row(4L, 0.4).pageBreak().row(5L, 0.5).row(4L, 0.41).row(6L, 0.6).pageBreak().build(); + + List offHeapInput = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT, DOUBLE), 2, ImmutableList.of(0), ImmutableList.of(ASC_NULLS_LAST)); + + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, DOUBLE).row(-1L, -0.1) + .row(1L, 0.1).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapInput, expected); + } + + @Test + public void testLimitZero() throws Exception + { + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT), 0, ImmutableList.of(0), ImmutableList.of(DESC_NULLS_LAST)); + try (Operator operator = operatorFactory.createOperator(driverContext)) { + assertNull(operator.getOutput()); + assertFalse(operator.needsInput()); + assertNull(operator.getOutput()); + } + } + + @Test(enabled = false) + public void testExceedMemoryLimit() throws Exception + { + List input = rowPagesBuilder(BIGINT).row(1L).build(); + + DriverContext smallDriverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, + new DataSize(1, BYTE)).addPipelineContext(0, true, true, false).addDriverContext(); + OperatorFactory operatorFactory = new TopNOmniOperator.TopNOmniOperatorFactory(0, new PlanNodeId("test"), + ImmutableList.of(BIGINT), 100, ImmutableList.of(0), ImmutableList.of(ASC_NULLS_LAST)); + try (Operator operator = operatorFactory.createOperator(smallDriverContext)) { + operator.addInput(input.get(0)); + operator.getOutput(); + fail("must fail because of exceeding local memory limit"); + } + catch (ExceededMemoryLimitException ignore) { + } + } +} diff --git 
a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestWindowOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestWindowOmniOperator.java new file mode 100644 index 0000000000000000000000000000000000000000..e153d3721d5aa71e9015b906dcfe30b3fe9d06a2 --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/operator/TestWindowOmniOperator.java @@ -0,0 +1,456 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.operator; + +import com.google.common.collect.ImmutableList; +import io.prestosql.metadata.Metadata; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.WindowFunctionDefinition; +import io.prestosql.operator.window.AggregateWindowFunction; +import io.prestosql.operator.window.FrameInfo; +import io.prestosql.operator.window.RankFunction; +import io.prestosql.operator.window.ReflectionWindowFunctionSupplier; +import io.prestosql.operator.window.RowNumberFunction; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.spi.connector.QualifiedObjectName; +import io.prestosql.spi.function.FunctionKind; +import io.prestosql.spi.function.Signature; +import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.type.BigintType; +import io.prestosql.spi.type.Type; +import io.prestosql.testing.MaterializedResult; +import io.prestosql.testing.TestingTaskContext; +import nova.hetu.olk.tool.OperatorUtils; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; +import org.testng.internal.collections.Ints; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.units.DataSize.succinctBytes; +import static io.prestosql.RowPagesBuilder.rowPagesBuilder; +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.WindowFunctionDefinition.window; +import static io.prestosql.spi.function.FunctionKind.AGGREGATE; +import static 
io.prestosql.spi.sql.expression.Types.FrameBoundType.UNBOUNDED_FOLLOWING; +import static io.prestosql.spi.sql.expression.Types.FrameBoundType.UNBOUNDED_PRECEDING; +import static io.prestosql.spi.sql.expression.Types.WindowFrameType.RANGE; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.testing.MaterializedResult.resultBuilder; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestWindowOmniOperator +{ + private static final Metadata METADATA = createTestMetadataManager(); + private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), + UNBOUNDED_FOLLOWING, Optional.empty()); + private static final List RANK = ImmutableList + .of(window(new ReflectionWindowFunctionSupplier<>("rank", BIGINT, ImmutableList.of(), RankFunction.class), + BIGINT, UNBOUNDED_FRAME)); + private static final List ROW_NUMBER = ImmutableList.of(window( + new ReflectionWindowFunctionSupplier<>("row_number", BIGINT, ImmutableList.of(), RowNumberFunction.class), + BIGINT, UNBOUNDED_FRAME)); + private static final List AVG = ImmutableList.of(window( + AggregateWindowFunction.supplier(new Signature( + QualifiedObjectName.valueOfDefaultFunction("avg"), FunctionKind.AGGREGATE, + DOUBLE.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("avg"), AGGREGATE, + DOUBLE.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + + private static final List SUM = ImmutableList.of(window( + AggregateWindowFunction.supplier(new Signature( + 
QualifiedObjectName.valueOfDefaultFunction("sum"), FunctionKind.AGGREGATE, + BigintType.BIGINT.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("sum"), AGGREGATE, + BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + private static final List MAX = ImmutableList.of(window( + AggregateWindowFunction.supplier(new Signature( + QualifiedObjectName.valueOfDefaultFunction("max"), FunctionKind.AGGREGATE, + BigintType.BIGINT.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("max"), AGGREGATE, + BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + private static final List MIN = ImmutableList.of(window( + AggregateWindowFunction.supplier(new Signature( + QualifiedObjectName.valueOfDefaultFunction("min"), FunctionKind.AGGREGATE, + BigintType.BIGINT.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("min"), AGGREGATE, + BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + private static final List COUNT_COLUMN = ImmutableList.of(window( + AggregateWindowFunction.supplier(new Signature( + QualifiedObjectName.valueOfDefaultFunction("count"), FunctionKind.AGGREGATE, + BigintType.BIGINT.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("count"), AGGREGATE, + BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + private static final List COUNT_ALL = ImmutableList.of(window( + AggregateWindowFunction.supplier(new 
Signature( + QualifiedObjectName.valueOfDefaultFunction("count"), FunctionKind.AGGREGATE, + BigintType.BIGINT.getTypeSignature(), BIGINT.getTypeSignature()), + METADATA.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("count"), AGGREGATE, + BIGINT.getTypeSignature(), BIGINT.getTypeSignature()))), + BIGINT, UNBOUNDED_FRAME, 1)); + private static final List RANK_AND_ROW_NUMBER = new ImmutableList.Builder() + .add(window(new ReflectionWindowFunctionSupplier<>("rank", BIGINT, ImmutableList.of(), RankFunction.class), + BIGINT, UNBOUNDED_FRAME)) + .add(window(new ReflectionWindowFunctionSupplier<>("row_number", BIGINT, ImmutableList.of(), RowNumberFunction.class), + BIGINT, UNBOUNDED_FRAME)) + .build(); + private ExecutorService executor; + private ScheduledExecutorService scheduledExecutor; + + @BeforeMethod + public void setUp() + { + executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s")); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); + } + + @Test + public void testRankPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), RANK, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = 
resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, BIGINT) + .row(1, 2L, 0.3, 1L).row(1, 4L, 0.2, 2L).row(1, 6L, 0.1, 3L).row(2, -1L, -0.1, 1L).row(2, 5L, 0.4, 2L) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testRankPartitionDiffLayout() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).row(1, 4L, 0.2).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(2, 0, 1), RANK, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE, INTEGER, BIGINT, BIGINT) + .row(0.3, 1, 2L, 1L).row(0.2, 1, 4L, 2L).row(0.2, 1, 4L, 2L).row(0.1, 1, 6L, 4L).row(-0.1, 2, -1L, 1L) + .row(0.4, 2, 5L, 2L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testRowNumberPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 2, 1), ROW_NUMBER, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new 
SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, DOUBLE, BIGINT, BIGINT) + .row(1, 0.3, 2L, 1L).row(1, 0.2, 4L, 2L).row(1, 0.1, 6L, 3L).row(2, -0.1, -1L, 1L).row(2, 0.4, 5L, 2L) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testRankAndRowNumberPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .row(1, 4L, 0.2).pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 2, 1), + RANK_AND_ROW_NUMBER, Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, DOUBLE, BIGINT, BIGINT, BIGINT) + .row(1, 0.3, 2L, 1L, 1L).row(1, 0.2, 4L, 2L, 2L).row(1, 0.2, 4L, 2L, 3L).row(1, 0.1, 6L, 4L, 4L) + .row(2, -0.1, -1L, 1L, 1L).row(2, 0.4, 5L, 2L, 2L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testAvgPartitionWithoutSort() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 
0, new PlanNodeId("test"), ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), AVG, + Ints.asList(0), Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[]{}), 0, 10000); + + List sourceTypes = operatorFactory.getSourceTypes(); + assertEquals(sourceTypes, ImmutableList.of(INTEGER, BIGINT, DOUBLE)); + assertTrue(operatorFactory.isExtensionOperatorFactory()); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, DOUBLE) + .row(1, 2L, 0.3, 4.0).row(1, 4L, 0.2, 4.0).row(1, 6L, 0.1, 4.0).row(2, -1L, -0.1, 2.0) + .row(2, 5L, 0.4, 2.0).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testAvgPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), AVG, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, DOUBLE) + .row(1, 2L, 0.3, 4.0).row(1, 4L, 0.2, 4.0).row(1, 6L, 0.1, 4.0).row(2, -1L, -0.1, 2.0) + .row(2, 5L, 0.4, 2.0).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testSumPartitionWithoutSort() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 
0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), SUM, + Ints.asList(0), Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[]{}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, BIGINT) + .row(1, 2L, 0.3, 12L).row(1, 4L, 0.2, 12L).row(1, 6L, 0.1, 12L).row(2, -1L, -0.1, 4L) + .row(2, 5L, 0.4, 4L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testSumPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), SUM, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, BIGINT) + .row(1, 2L, 0.3, 12L).row(1, 4L, 0.2, 12L).row(1, 6L, 0.1, 12L).row(2, -1L, -0.1, 4L) + .row(2, 5L, 0.4, 4L).build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testMaxPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 
0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), MAX, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, BIGINT) + .row(1, 2L, 0.3, 6L).row(1, 4L, 0.2, 6L).row(1, 6L, 0.1, 6L).row(2, -1L, -0.1, 5L).row(2, 5L, 0.4, 5L) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testMinPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), MIN, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, BIGINT) + .row(1, 2L, 0.3, 2L).row(1, 4L, 0.2, 2L).row(1, 6L, 0.1, 2L).row(2, -1L, -0.1, -1L).row(2, 5L, 0.4, -1L) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testCountColumnPartition() + { + List 
input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), COUNT_COLUMN, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, INTEGER) + .row(1, 2L, 0.3, 3).row(1, 4L, 0.2, 3).row(1, 6L, 0.1, 3).row(2, -1L, -0.1, 2).row(2, 5L, 0.4, 2) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testCountColumnPartitionWithoutSort() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, 0.1).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), + COUNT_COLUMN, Ints.asList(0), Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[]{}), 0, + 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, INTEGER) + .row(1, 2L, 0.3, 3).row(1, 4L, 0.2, 3).row(1, 6L, 0.1, 3).row(2, -1L, -0.1, 2).row(2, 5L, 0.4, 2) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, 
expected, false); + } + + @Test + public void testCountAllPartition() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, null).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), + ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), COUNT_ALL, + Ints.asList(0), Ints.asList(), Ints.asList(1), + ImmutableList.copyOf(new SortOrder[]{SortOrder.ASC_NULLS_LAST}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, INTEGER) + .row(1, 2L, 0.3, 3).row(1, 4L, 0.2, 3).row(1, 6L, null, 3).row(2, -1L, -0.1, 2).row(2, 5L, 0.4, 2) + .build(); + + assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + @Test + public void testCountAllPartitionWithoutSort() + { + List input = rowPagesBuilder(INTEGER, BIGINT, DOUBLE).row(2, -1L, -0.1).row(1, 2L, 0.3).row(1, 4L, 0.2) + .pageBreak().row(2, 5L, 0.4).row(1, 6L, null).build(); + List offHeapPages = OperatorUtils.transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, input); + + WindowOmniOperator.WindowOmniOperatorFactory operatorFactory = new WindowOmniOperator.WindowOmniOperatorFactory( + 0, new PlanNodeId("test"), ImmutableList.of(INTEGER, BIGINT, DOUBLE), Ints.asList(0, 1, 2), COUNT_ALL, + Ints.asList(0), Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[]{}), 0, 10000); + + DriverContext driverContext = createDriverContext(); + MaterializedResult expected = resultBuilder(driverContext.getSession(), INTEGER, BIGINT, DOUBLE, INTEGER) + .row(1, 2L, 0.3, 3).row(1, 4L, 0.2, 3).row(1, 6L, null, 3).row(2, -1L, -0.1, 2).row(2, 5L, 0.4, 2) + .build(); + + 
assertOperatorEquals(operatorFactory, driverContext, offHeapPages, expected, false); + } + + private DriverContext createDriverContext() + { + return createDriverContext(Long.MAX_VALUE); + } + + private DriverContext createDriverContext(long memoryLimit) + { + return TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) + .setMemoryPoolSize(succinctBytes(memoryLimit)).build().addPipelineContext(0, true, true, false) + .addDriverContext(); + } +} diff --git a/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/tool/TestOperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/tool/TestOperatorUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..f094276ffc7e0259e8d575bd7f15bdf9770fe84b --- /dev/null +++ b/omnioperator/omniop-openlookeng-extension/src/test/omni/nova/hetu/olk/tool/TestOperatorUtils.java @@ -0,0 +1,280 @@ +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nova.hetu.olk.tool; + +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.StandardErrorCode; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.RowBlock; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import nova.hetu.olk.PageBuilderUtil; +import nova.hetu.olk.block.DictionaryOmniBlock; +import nova.hetu.olk.block.RowOmniBlock; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.DecimalVec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.VecAllocator; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static nova.hetu.olk.tool.OperatorUtils.transferToOffHeapPages; +import static nova.hetu.olk.tool.OperatorUtils.transferToOnHeapPage; +import static nova.hetu.olk.tool.OperatorUtils.transferToOnHeapPages; +import static org.testng.Assert.assertEquals; + +public class TestOperatorUtils +{ + private List types = new ImmutableList.Builder().add(INTEGER).add(BIGINT).add(REAL).add(DOUBLE).add(VARCHAR) + 
.add(DATE).add(TIMESTAMP).add(BOOLEAN).add(createDecimalType(20, 10)).build(); + + @Test + public void testBasicTransfer() + { + List pages = buildPages(types, false, 100); + // transfer on-heap page to off-heap + List offHeapPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + assertPagesEquals(types, pages, offHeapPages); + // transfer off-heap page to on-heap + List onHeapPages = transferToOnHeapPages(offHeapPages); + + assertPagesEquals(types, onHeapPages, offHeapPages); + freeNativeMemory(offHeapPages); + } + + @Test + public void testDictionaryTransfer() + { + List pages = buildPages(types, true, 100); + // transfer on-heap page to off-heap + List offHeapPages = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, pages); + + assertPagesEquals(types, pages, offHeapPages); + // transfer off-heap page to on-heap + List onHeapPages = transferToOnHeapPages(offHeapPages); + + assertPagesEquals(types, onHeapPages, offHeapPages); + freeNativeMemory(offHeapPages); + } + + @Test + public void testRowBlockTransfer() + { + Type type = BIGINT; + Page page = new Page(buildRowBlockByBuilder(type)); + + // transfer on-heap page to off-heap + Page offHeapPage = transferToOffHeapPages(VecAllocator.GLOBAL_VECTOR_ALLOCATOR, page, + ImmutableList.of(RowType.anonymous(ImmutableList.of(type)))); + + assertPageEquals(type, page, offHeapPage); + // transfer off-heap page to on-heap + Page onHeapPage = transferToOnHeapPage(offHeapPage); + + assertPageEquals(type, onHeapPage, offHeapPage); + BlockUtils.freePage(offHeapPage); + } + + public static void freeNativeMemory(List offHeapPages) + { + for (Page page : offHeapPages) { + BlockUtils.freePage(page); + } + } + + public static void assertPagesEquals(List types, List actual, List expected) + { + for (int i = 0; i < actual.size(); i++) { + assertPageEquals(types.get(i), actual.get(i), expected.get(i)); + } + } + + public static void assertPageEquals(Type type, Page actual, Page expected) + { + for 
(int i = 0; i < actual.getChannelCount(); i++) { + assertBlockEquals(type, actual.getBlock(i), expected.getBlock(i)); + } + } + + private static void assertBlockEquals(Type type, Block actual, Block expected) + { + if (actual.isExtensionBlock() && expected.isExtensionBlock()) { + assertOmniBlockEquals(actual, expected); + } + else if (actual.isExtensionBlock() || expected.isExtensionBlock()) { + if (actual.isExtensionBlock()) { + assertOlkAndOmniBlockEquals(expected, actual); + } + else { + assertOlkAndOmniBlockEquals(actual, expected); + } + } + else { + assertOlkBlockEquals(type, actual, expected); + } + } + + private static void assertOmniBlockEquals(Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), expected.get(position)); + } + } + + private static void assertOlkBlockEquals(Type type, Block actual, Block expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), expected.get(position)); + } + } + + private static void assertOlkAndOmniBlockEquals(Block block, Block omniBlock) + { + Object blockValues = omniBlock.getValues(); + switch (block.getClass().getSimpleName()) { + case "ByteArrayBlock": + assertBlockEquals(block, (BooleanVec) blockValues); + return; + case "IntArrayBlock": + assertBlockEquals(block, (IntVec) blockValues); + return; + case "LongArrayBlock": + assertBlockEquals(block, (LongVec) blockValues); + return; + case "DoubleArrayBlock": + assertBlockEquals(block, (DoubleVec) blockValues); + return; + case "Int128ArrayBlock": + assertBlockEquals(block, (DecimalVec) blockValues); + return; + case "VariableWidthBlock": + assertBlockEquals(block, (VarcharVec) blockValues); + return; + case "DictionaryBlock": + assertOlkAndOmniBlockEquals(((DictionaryBlock) block).getDictionary(), ((DictionaryOmniBlock) omniBlock).getDictionary()); + return; + case "RowBlock": + 
assertRowBlockEquals((RowBlock) block, (RowOmniBlock) omniBlock); + return; + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, + "Not support block:" + block.getClass().getSimpleName()); + } + } + + private static void assertRowBlockEquals(RowBlock actual, RowOmniBlock expected) + { + Block[] rawFieldBlocks = actual.getRawFieldBlocks(); + Block[] expectedRawFieldBlocks = expected.getRawFieldBlocks(); + assertEquals(rawFieldBlocks.length, expectedRawFieldBlocks.length); + for (int i = 0; i < rawFieldBlocks.length; i++) { + assertEquals(rawFieldBlocks[i].getPositionCount(), expectedRawFieldBlocks[i].getPositionCount()); + assertOlkAndOmniBlockEquals((Block) rawFieldBlocks[i], (Block) expectedRawFieldBlocks[i]); + } + } + + private static void assertBlockEquals(Block actual, DecimalVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(actual.get(position), expected.get(position)); + } + } + + private static void assertBlockEquals(Block actual, VarcharVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals(new String((byte[]) actual.get(position)), new String(expected.get(position))); + } + } + + private static void assertBlockEquals(Block actual, DoubleVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals((Double) actual.get(position), new Double(expected.get(position))); + } + } + + private static void assertBlockEquals(Block actual, LongVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals((Long) actual.get(position), new Long(expected.get(position))); + } + } + + private static void assertBlockEquals(Block actual, IntVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals((Integer) actual.get(position), new Integer(expected.get(position))); + } + } + + private static void 
assertBlockEquals(Block actual, BooleanVec expected) + { + for (int position = 0; position < actual.getPositionCount(); position++) { + assertEquals((actual.get(position)), (expected.get(position)) ? (byte) 1 : (byte) 0); + } + } + + public static List buildPages(List typesArray, boolean dictionaryBlocks, int rows) + { + List pages = new ArrayList<>(); + for (int i = 0; i < typesArray.size(); i++) { + if (dictionaryBlocks) { + pages.add(PageBuilderUtil.createSequencePageWithDictionaryBlocks(typesArray, rows)); + } + else { + pages.add(PageBuilderUtil.createSequencePage(typesArray, rows)); + } + } + return pages; + } + + public static Block buildRowBlockByBuilder(Type type) + { + BlockBuilder rowBlockBuilder = type.createBlockBuilder(null, 4); + type.writeLong(rowBlockBuilder, 1); + type.writeLong(rowBlockBuilder, 10); + type.writeLong(rowBlockBuilder, 100); + type.writeLong(rowBlockBuilder, 1000); + Block block = rowBlockBuilder.build(); + + boolean[] rowIsNull = new boolean[1]; + int[] fieldBlockOffsets = {0, 1}; + Block[] blocks = new Block[1]; + blocks[0] = block; + return RowBlock.fromFieldBlocks(1, Optional.of(rowIsNull), blocks); + } +}