diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index 4265d4c03e50949328142f1608e3f82a53b1f8ba..199247217475b4f1a5407f253add457578d2ae84 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -24,6 +24,29 @@ using namespace omniruntime::vec; SplitOptions SplitOptions::Defaults() { return SplitOptions(); } +void Splitter::BuildPartition2Row(int32_t num_rows) +{ + row_offset_row_id_.resize(num_rows); + partition_row_offset_base_.resize(num_partitions_ + 1); + partition_row_offset_base_[0] = 0; + for (auto pid = 1; pid <= num_partitions_; ++pid) { + partition_row_offset_base_[pid] = partition_row_offset_base_[pid - 1] + partition_id_cnt_cur_[pid - 1]; + } + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + row_offset_row_id_[partition_row_offset_base_[pid]++] = row; + } + for (auto pid = 0; pid < num_partitions_; ++pid) { + partition_row_offset_base_[pid] -= partition_id_cnt_cur_[pid]; + } + partition_used_.clear(); + for (auto pid = 0; pid != num_partitions_; ++pid) { + if (partition_id_cnt_cur_[pid] > 0) { + partition_used_.push_back(pid); + } + } +} + // 计算分区id,每个batch初始化 int Splitter::ComputeAndCountPartitionId(VectorBatch& vb) { auto num_rows = vb.GetRowCount(); @@ -130,46 +153,38 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { if (vb.Get(col_idx_vb)->GetEncoding() == OMNI_DICTIONARY) { LogsDebug("Dictionary Columnar process!"); - auto ids_addr = VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb)); + auto ids_addr = static_cast(VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb))); auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vb.Get(col_idx_vb))); + auto process = [&](const ShuffleTypeId shuffleTypeId) { + const auto shuffle_size = (1 << shuffleTypeId); + for (auto &pid: partition_used_) { + auto dstPidBase = reinterpret_cast(dst_addrs[pid]) + partition_buffer_idx_base_[pid]; + auto pos = partition_row_offset_base_[pid]; + auto end = partition_row_offset_base_[pid + 1]; + auto count = end - pos; + for (; pos < end; ++pos) { + auto rowId = row_offset_row_id_[pos]; + *dstPidBase++ = reinterpret_cast(src_addr)[ids_addr[rowId]]; + } + partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size * count; + partition_buffer_idx_offset_[pid] += count; + } + }; switch (column_type_id_[col_idx_schema]) { -#define PROCESS(SHUFFLE_TYPE, CTYPE) \ - case SHUFFLE_TYPE: \ - { \ - auto shuffle_size = (1 << SHUFFLE_TYPE); \ - for (auto row = 0; row < num_rows; ++row) { \ - auto pid = partition_id_[row]; \ - auto dst_offset = \ - partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ - reinterpret_cast(dst_addrs[pid])[dst_offset] = \ - reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row]]; \ - partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size; \ - partition_buffer_idx_offset_[pid]++; \ - } \ - } \ - break; - PROCESS(SHUFFLE_1BYTE, uint8_t) - PROCESS(SHUFFLE_2BYTE, uint16_t) - PROCESS(SHUFFLE_4BYTE, uint32_t) - PROCESS(SHUFFLE_8BYTE, uint64_t) -#undef PROCESS + case SHUFFLE_1BYTE: + process.operator()(SHUFFLE_1BYTE); + break; + case SHUFFLE_2BYTE: + process.operator()(SHUFFLE_2BYTE); + break; + case SHUFFLE_4BYTE: + process.operator()(SHUFFLE_4BYTE); + break; + case SHUFFLE_8BYTE: + process.operator()(SHUFFLE_8BYTE); + break; case SHUFFLE_DECIMAL128: - { - auto shuffle_size = (1 << SHUFFLE_DECIMAL128); - for (auto row = 0; row < num_rows; ++row) { - auto pid = partition_id_[row]; - auto dst_offset = - partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; - // 前64位取值、赋值 - reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = - reinterpret_cast(src_addr)[reinterpret_cast(ids_addr)[row] << 1]; - // 后64位取值、赋值 - reinterpret_cast(dst_addrs[pid])[(dst_offset << 1) | 1] = - reinterpret_cast(src_addr)[(reinterpret_cast(ids_addr)[row] << 1) | 1]; - partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size; //decimal128 16Bytes - partition_buffer_idx_offset_[pid]++; - } - } + process.operator()(SHUFFLE_DECIMAL128); break; default: { LogsError("SplitFixedWidthValueBuffer not match this type: %d", column_type_id_[col_idx_schema]); @@ -178,42 +193,37 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { } } else { auto src_addr = reinterpret_cast(VectorHelper::UnsafeGetValues(vb.Get(col_idx_vb))); + auto process = [&](const ShuffleTypeId shuffleTypeId) { + const auto shuffle_size = (1 << shuffleTypeId); + for (auto &pid: partition_used_) { + auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + auto dstPidBase = reinterpret_cast(dst_addrs[pid]) + dst_offset; + auto pos = partition_row_offset_base_[pid]; + auto end = partition_row_offset_base_[pid + 1]; + auto count = end - pos; + for (; pos < end; ++pos) { + auto rowId = row_offset_row_id_[pos]; + *dstPidBase++ = reinterpret_cast(src_addr)[rowId]; + } + partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size * count; + partition_buffer_idx_offset_[pid] += count; + } + }; switch (column_type_id_[col_idx_schema]) { -#define PROCESS(SHUFFLE_TYPE, CTYPE) \ - case SHUFFLE_TYPE: \ - { \ - auto shuffle_size = (1 << SHUFFLE_TYPE); \ - for (auto row = 0; row < num_rows; ++row) { \ - auto pid = partition_id_[row]; \ - auto dst_offset = \ - partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; \ - reinterpret_cast(dst_addrs[pid])[dst_offset] = \ - reinterpret_cast(src_addr)[row]; \ - partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size; \ - partition_buffer_idx_offset_[pid]++; \ - } \ - } \ - break; - PROCESS(SHUFFLE_1BYTE, uint8_t) - PROCESS(SHUFFLE_2BYTE, uint16_t) - PROCESS(SHUFFLE_4BYTE, uint32_t) - PROCESS(SHUFFLE_8BYTE, uint64_t) -#undef PROCESS + case SHUFFLE_1BYTE: + process.operator()(SHUFFLE_1BYTE); + break; + case SHUFFLE_2BYTE: + process.operator()(SHUFFLE_2BYTE); + break; + case SHUFFLE_4BYTE: + process.operator()(SHUFFLE_4BYTE); + break; + case SHUFFLE_8BYTE: + process.operator()(SHUFFLE_8BYTE); + break; case SHUFFLE_DECIMAL128: - { - auto shuffle_size = (1 << SHUFFLE_DECIMAL128); - for (auto row = 0; row < num_rows; ++row) { - auto pid = partition_id_[row]; - auto dst_offset = - partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; - reinterpret_cast(dst_addrs[pid])[dst_offset << 1] = - reinterpret_cast(src_addr)[row << 1]; // 前64位取值、赋值 - reinterpret_cast(dst_addrs[pid])[(dst_offset << 1) | 1] = - reinterpret_cast(src_addr)[(row << 1) | 1]; // 后64位取值、赋值 - partition_fixed_width_buffers_[col][pid][1]->size_ += shuffle_size; //decimal128 16Bytes - partition_buffer_idx_offset_[pid]++; - } - } + process.operator()(SHUFFLE_DECIMAL128); break; default: { LogsError("ERROR: SplitFixedWidthValueBuffer not match this type: %d", column_type_id_[col_idx_schema]); @@ -449,6 +459,8 @@ int Splitter::DoSplit(VectorBatch& vb) { } } } + BuildPartition2Row(vb.GetRowCount()); + SplitFixedWidthValueBuffer(vb); SplitFixedWidthValidityBuffer(vb); diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index 9f0e8fa582aa30e7ef6933546367854e50fa986c..d30eccf4d6de7e55b7722d9eec6420975c000c72 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -126,6 +126,9 @@ class Splitter { std::vector partition_id_; // 记录当前vb每一行的pid int32_t *partition_id_cnt_cur_; // 统计不同partition记录的行数(当前处理中的vb) uint64_t *partition_id_cnt_cache_; // 统计不同partition记录的行数,cache住的 + std::vector row_offset_row_id_; + std::vector partition_used_; + std::vector partition_row_offset_base_; // column number uint32_t num_row_splited_; // cached row number uint64_t cached_vectorbatch_size_; // cache total vectorbatch size in bytes @@ -160,6 +163,8 @@ class Splitter { spark::ProtoRowBatch *protoRowBatch = new ProtoRowBatch(); private: + void BuildPartition2Row(int32_t row_count); + void ReleaseVarcharVector() { std::set::iterator it;