Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ private class GlutenOptimizedWriterShuffleReader(
case _ =>
SparkEnv.get.serializerManager
}
val wrappedStreams = new ShuffleBlockFetcherIterator(
val wrappedStreams = new GlutenShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.blockStoreClient,
SparkEnv.get.blockManager,
Expand All @@ -335,7 +335,7 @@ private class GlutenOptimizedWriterShuffleReader(
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
false
).toCompletionIterator
)

// Create a key/value iterator for each stream
val recordIter = dep match {
Expand All @@ -344,12 +344,12 @@ private class GlutenOptimizedWriterShuffleReader(
columnarDep.serializer
.newInstance()
.asInstanceOf[ColumnarBatchSerializerInstance]
.deserializeStreams(wrappedStreams)
.deserializeStreams(wrappedStreams, wrappedStreams.cleanup)
.asKeyValueIterator
case _ =>
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
wrappedStreams.flatMap {
wrappedStreams.toCompletionIterator.flatMap {
case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,20 @@ private class ColumnarBatchSerializerInstanceImpl(
shuffleReaderHandle
}

// TODO: remove this method for columnar shuffle.
override def deserializeStream(in: InputStream): DeserializationStream = {
new TaskDeserializationStream(Iterator((null, in)))
}

override def deserializeStreams(
streams: Iterator[(BlockId, InputStream)]): DeserializationStream = {
new TaskDeserializationStream(streams)
streams: Iterator[(BlockId, InputStream)],
completionFunction: () => Unit): DeserializationStream = {
new TaskDeserializationStream(streams, Some(completionFunction))
}

private class TaskDeserializationStream(streams: Iterator[(BlockId, InputStream)])
private class TaskDeserializationStream(
streams: Iterator[(BlockId, InputStream)],
completionFunction: Option[() => Unit] = None)
extends DeserializationStream
with TaskResource {
private val streamReader = ShuffleStreamReader(streams)
Expand Down Expand Up @@ -219,6 +223,9 @@ private class ColumnarBatchSerializerInstanceImpl(
if (!closeCalled.compareAndSet(false, true)) {
return
}
// Stop reading more streams. Blocked by the native reader threads.
jniWrapper.stop(shuffleReaderHandle)
completionFunction.foreach(_())
// Would remove the resource object from registry to lower GC pressure.
TaskResources.releaseResource(resourceId)
}
Expand All @@ -242,7 +249,6 @@ private class ColumnarBatchSerializerInstanceImpl(
}
numOutputRows += numRowsTotal
wrappedOut.close()
streamReader.close()
if (cb != null) {
cb.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ import scala.reflect.ClassTag
abstract class ColumnarBatchSerializerInstance extends SerializerInstance {

/** Deserialize the streams of ColumnarBatches. */
def deserializeStreams(streams: Iterator[(BlockId, InputStream)]): DeserializationStream
def deserializeStreams(
streams: Iterator[(BlockId, InputStream)],
completionFunction: () => Unit): DeserializationStream

override def serialize[T: ClassTag](t: T): ByteBuffer = {
throw new UnsupportedOperationException
Expand All @@ -44,4 +46,8 @@ abstract class ColumnarBatchSerializerInstance extends SerializerInstance {
override def serializeStream(s: OutputStream): SerializationStream = {
throw new UnsupportedOperationException
}

override def deserializeStream(s: InputStream): DeserializationStream = {
throw new UnsupportedOperationException
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, GlutenShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator

/**
Expand Down Expand Up @@ -70,7 +70,7 @@ class ColumnarShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val wrappedStreams = new ShuffleBlockFetcherIterator(
val shuffleBlockFetcherIterator = new GlutenShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
blockManager,
Expand All @@ -89,20 +89,22 @@ class ColumnarShuffleReader[K, C](
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
fetchContinuousBlocksInBatch
).toCompletionIterator
)

val recordIter = dep match {
case columnarDep: ColumnarShuffleDependency[K, _, C] =>
// If the dependency is a ColumnarShuffleDependency, we use the columnar serializer.
columnarDep.serializer
.newInstance()
.asInstanceOf[ColumnarBatchSerializerInstance]
.deserializeStreams(wrappedStreams)
.deserializeStreams(
shuffleBlockFetcherIterator,
shuffleBlockFetcherIterator.cleanup)
.asKeyValueIterator
case _ =>
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
wrappedStreams.flatMap {
shuffleBlockFetcherIterator.toCompletionIterator.flatMap {
case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
Expand Down
2 changes: 2 additions & 0 deletions cpp/core/config/GlutenConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ constexpr bool kCudfEnabledDefault = false;
const std::string kDebugCudf = "spark.gluten.sql.debug.cudf";
const std::string kDebugCudfDefault = "false";

const std::string kShuffleReaderThreads = "spark.gluten.sql.columnar.shuffle.numReaderThreads";

std::unordered_map<std::string, std::string>
parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength);

Expand Down
10 changes: 10 additions & 0 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,16 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
JNI_METHOD_END()
}

JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_stop( // NOLINT
JNIEnv* env,
jobject wrapper,
jlong shuffleReaderHandle) {
JNI_METHOD_START
auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
reader->stop();
JNI_METHOD_END()
}

JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_close( // NOLINT
JNIEnv* env,
jobject wrapper,
Expand Down
4 changes: 4 additions & 0 deletions cpp/core/shuffle/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <arrow/ipc/options.h>
#include <arrow/util/compression.h>
#include <thread>

namespace gluten {

Expand Down Expand Up @@ -68,6 +69,9 @@ struct ShuffleReaderOptions {
// Whether to enable the reader-side raw payload merge fast path for plain hash shuffle payloads within one input
// stream.
bool enableHashShuffleReaderStreamMerge = false;

// Thread number for async shuffle read.
int32_t numReaderThreads = std::thread::hardware_concurrency();
};

struct ShuffleWriterOptions {
Expand Down
2 changes: 2 additions & 0 deletions cpp/core/shuffle/ShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ShuffleReader {
virtual int64_t getDecompressTime() const = 0;

virtual int64_t getDeserializeTime() const = 0;

virtual void stop() = 0;
};

} // namespace gluten
1 change: 1 addition & 0 deletions cpp/velox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ set(VELOX_SRCS
operators/writer/VeloxColumnarBatchWriter.cc
operators/writer/VeloxParquetDataSource.cc
shuffle/ArrowShuffleDictionaryWriter.cc
shuffle/ReaderThreadPool.cc
shuffle/VeloxHashShuffleWriter.cc
shuffle/VeloxRssSortShuffleWriter.cc
shuffle/VeloxShuffleReader.cc
Expand Down
9 changes: 9 additions & 0 deletions cpp/velox/compute/VeloxBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ void VeloxBackend::init(
velox::exec::Operator::registerOperator(std::make_unique<CudfVectorStreamOperatorTranslator>());
velox::cudf_velox::registerSparkFunctions("");
velox::cudf_velox::registerSparkAggregateFunctions("");
readerThreadPool_ = std::make_unique<ReaderThreadPool>(
backendConf_->get<int32_t>(kShuffleReaderThreads, std::thread::hardware_concurrency()));
}
#endif

Expand Down Expand Up @@ -294,12 +296,19 @@ void VeloxBackend::init(
registerShuffleDictionaryWriterFactory([](MemoryManager* memoryManager, arrow::util::Codec* codec) {
return std::make_unique<ArrowShuffleDictionaryWriter>(memoryManager, codec);
});

readerThreadPool_ = std::make_unique<ReaderThreadPool>(
backendConf_->get<int32_t>(kShuffleReaderThreads, std::thread::hardware_concurrency()));
}

facebook::velox::cache::AsyncDataCache* VeloxBackend::getAsyncDataCache() const {
return asyncDataCache_.get();
}

ReaderThreadPool* VeloxBackend::getReaderThreadPool() const {
return readerThreadPool_.get();
}

// JNI-or-local filesystem, for spilling-to-heap if we have extra JVM heap spaces
void VeloxBackend::initJolFilesystem() {
int64_t maxSpillFileSize = backendConf_->get<int64_t>(kMaxSpillFileSize, kMaxSpillFileSizeDefault);
Expand Down
5 changes: 5 additions & 0 deletions cpp/velox/compute/VeloxBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "jni/JniHashTable.h"
#include "memory/VeloxMemoryManager.h"
#include "shuffle/ReaderThreadPool.h"

namespace gluten {

Expand All @@ -50,6 +51,8 @@ class VeloxBackend {

facebook::velox::cache::AsyncDataCache* getAsyncDataCache() const;

ReaderThreadPool* getReaderThreadPool() const;

std::shared_ptr<facebook::velox::config::ConfigBase> getBackendConf() const {
return backendConf_;
}
Expand Down Expand Up @@ -126,6 +129,8 @@ class VeloxBackend {
std::string cacheFilePrefix_;

std::shared_ptr<facebook::velox::config::ConfigBase> backendConf_;

std::unique_ptr<ReaderThreadPool> readerThreadPool_;
};

} // namespace gluten
4 changes: 1 addition & 3 deletions cpp/velox/compute/VeloxRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ std::shared_ptr<ShuffleReader> VeloxRuntime::createShuffleReader(
const auto veloxCompressionKind = arrowCompressionTypeToVelox(options.compressionType);
const auto rowType = facebook::velox::asRowType(gluten::fromArrowSchema(schema));

auto deserializerFactory = std::make_unique<gluten::VeloxShuffleReaderDeserializerFactory>(
return std::make_shared<gluten::VeloxShuffleReader>(
schema,
std::move(codec),
veloxCompressionKind,
Expand All @@ -617,8 +617,6 @@ std::shared_ptr<ShuffleReader> VeloxRuntime::createShuffleReader(
memoryManager(),
options.shuffleWriterType,
options.enableHashShuffleReaderStreamMerge);

return std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
}

std::unique_ptr<ColumnarBatchSerializer> VeloxRuntime::createColumnarBatchSerializer(struct ArrowSchema* cSchema) {
Expand Down
99 changes: 99 additions & 0 deletions cpp/velox/shuffle/ReaderThreadPool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "shuffle/ReaderThreadPool.h"
#include <glog/logging.h>

namespace gluten {

ReaderThreadPool::ReaderThreadPool(size_t numThreads) : numThreads_(numThreads) {
workers_.reserve(numThreads);
for (size_t i = 0; i < numThreads; ++i) {
workers_.emplace_back([this]() { workerThread(); });
}
LOG(WARNING) << "Created ReaderThreadPool with " << numThreads << " threads.";
}

ReaderThreadPool::~ReaderThreadPool() {
shutdown();
}

void ReaderThreadPool::submitBatch(std::vector<Task> tasks, int32_t priority) {
std::lock_guard<std::mutex> lock(taskQueueMtx_);
if (stop_.load(std::memory_order_acquire)) {
return;
}
for (auto& task : tasks) {
tasks_.push({std::move(task), priority});
}
}

void ReaderThreadPool::start() {
// Wake up all worker threads to start processing.
wakeUpCV_.notify_all();
LOG(WARNING) << "Started ReaderThreadPool execution.";
}

void ReaderThreadPool::shutdown() {
if (!isShutdown()) {
stop_.store(true, std::memory_order_release);
wakeUpCV_.notify_all();

// Wait for all worker threads to finish their current tasks and join.
for (auto& worker : workers_) {
if (worker.joinable()) {
worker.join();
}
}
}
}

void ReaderThreadPool::workerThread() {
while (true) {
{
std::unique_lock<std::mutex> lock(taskQueueMtx_);

wakeUpCV_.wait(lock, [this]() { return stop_.load(std::memory_order_acquire) || !tasks_.empty(); });

if (stop_.load(std::memory_order_acquire)) {
// Discard remaining tasks and exit the thread.
return;
}
}

while (true) {
Task task;
{
std::lock_guard<std::mutex> lock(taskQueueMtx_);
if (tasks_.empty()) {
break;
}
auto& prioritizedTask = tasks_.top();
LOG(WARNING) << "Worker thread " << std::this_thread::get_id() << " is executing a task with priority "
<< prioritizedTask.priority;
task = std::move(prioritizedTask.task);
tasks_.pop();
}

if (task) {
task();
}
}
}
}

} // namespace gluten
Loading
Loading