Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Posts CoreWorkerMemoryStore callbacks onto io_context. #47833

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/src/ray/runtime/object/local_mode_object_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
namespace ray {
namespace internal {
LocalModeObjectStore::LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime)
: local_mode_ray_tuntime_(local_mode_ray_tuntime) {
memory_store_ = std::make_unique<CoreWorkerMemoryStore>();
: io_context_("LocalModeObjectStore"),
local_mode_ray_tuntime_(local_mode_ray_tuntime) {
memory_store_ = std::make_unique<CoreWorkerMemoryStore>(&io_context_.GetIoService());
}

void LocalModeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/ray/runtime/object/local_mode_object_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "../local_mode_ray_runtime.h"
#include "object_store.h"
#include "ray/common/asio/asio_util.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"

namespace ray {
Expand All @@ -27,7 +28,7 @@ using ray::core::CoreWorkerMemoryStore;

class LocalModeObjectStore : public ObjectStore {
public:
LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime);
explicit LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime);

std::vector<bool> Wait(const std::vector<ObjectID> &ids,
int num_objects,
Expand All @@ -47,6 +48,7 @@ class LocalModeObjectStore : public ObjectStore {
std::vector<std::shared_ptr<msgpack::sbuffer>> GetRaw(const std::vector<ObjectID> &ids,
int timeout_ms);

InstrumentedIOContextWithThread io_context_;
std::unique_ptr<CoreWorkerMemoryStore> memory_store_;

LocalModeRayRuntime &local_mode_ray_tuntime_;
Expand Down
1 change: 1 addition & 0 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
options_.worker_type != WorkerType::RESTORE_WORKER),
/*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this)));
memory_store_.reset(new CoreWorkerMemoryStore(
&io_service_,
reference_counter_,
local_raylet_client_,
options_.check_signals,
Expand Down
38 changes: 28 additions & 10 deletions src/ray/core_worker/store_provider/memory_store/memory_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <condition_variable>
#include <utility>

#include "ray/common/ray_config.h"
#include "ray/core_worker/context.h"
Expand Down Expand Up @@ -148,17 +149,29 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
}

CoreWorkerMemoryStore::CoreWorkerMemoryStore(
instrumented_io_context *io_context,
std::shared_ptr<ReferenceCounter> counter,
std::shared_ptr<raylet::RayletClient> raylet_client,
std::function<Status()> check_signals,
std::function<void(const RayObject &)> unhandled_exception_handler,
std::function<std::shared_ptr<ray::RayObject>(
const ray::RayObject &object, const ObjectID &object_id)> object_allocator)
: ref_counter_(std::move(counter)),
raylet_client_(raylet_client),
check_signals_(check_signals),
unhandled_exception_handler_(unhandled_exception_handler),
object_allocator_(std::move(object_allocator)) {}
: owned_io_context_with_thread_(
io_context == nullptr ? std::make_unique<InstrumentedIOContextWithThread>(
"TestOnly.CoreWorkerMemoryStore")
: nullptr),
io_context_(io_context == nullptr ? owned_io_context_with_thread_->GetIoService()
: *io_context),
ref_counter_(std::move(counter)),
raylet_client_(std::move(raylet_client)),
check_signals_(std::move(check_signals)),
unhandled_exception_handler_(std::move(unhandled_exception_handler)),
object_allocator_(std::move(object_allocator)) {
if (owned_io_context_with_thread_ != nullptr) {
RAY_LOG(WARNING) << "io_context not provided to CoreWorkerMemoryStore! This should "
"only happen in cpp tests.";
}
}

void CoreWorkerMemoryStore::GetAsync(
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
Expand All @@ -177,7 +190,8 @@ void CoreWorkerMemoryStore::GetAsync(
}
// It's important for performance to run the callback outside the lock.
if (ptr != nullptr) {
callback(ptr);
io_context_.post([callback, ptr]() { callback(ptr); },
"CoreWorkerMemoryStore.GetAsync");
}
}

Expand All @@ -198,7 +212,7 @@ std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob

bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) {
std::vector<std::function<void(std::shared_ptr<RayObject>)>> async_callbacks;
RAY_LOG(DEBUG) << "Putting object into memory store. objectid is " << object_id;
RAY_LOG(DEBUG).WithField(object_id) << "Putting object into memory store.";
std::shared_ptr<RayObject> object_entry = nullptr;
if (object_allocator_ != nullptr) {
object_entry = object_allocator_(object, object_id);
Expand Down Expand Up @@ -256,9 +270,13 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
}

// It's important for performance to run the callbacks outside the lock.
for (const auto &cb : async_callbacks) {
cb(object_entry);
}
io_context_.post(
[async_callbacks, object_entry]() {
for (const auto &cb : async_callbacks) {
cb(object_entry);
}
},
"CoreWorkerMemoryStore.Put.async_callbacks");

return true;
}
Expand Down
16 changes: 12 additions & 4 deletions src/ray/core_worker/store_provider/memory_store/memory_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/asio/asio_util.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/reference_count.h"

Expand All @@ -44,20 +44,24 @@ class CoreWorkerMemoryStore {
public:
/// Create a memory store.
///
/// \param[in] io_context The io_context onto which async callbacks are posted.
/// TESTONLY: if nullptr, the store creates and owns a dedicated thread and uses its
/// context instead.
/// \param[in] counter If not null, this enables ref counting for local objects,
/// and the `remove_after_get` flag for Get() will be ignored.
/// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked.
CoreWorkerMemoryStore(
explicit CoreWorkerMemoryStore(
instrumented_io_context *io_context = nullptr,
std::shared_ptr<ReferenceCounter> counter = nullptr,
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
std::function<Status()> check_signals = nullptr,
std::function<void(const RayObject &)> unhandled_exception_handler = nullptr,
std::function<std::shared_ptr<RayObject>(const RayObject &object,
const ObjectID &object_id)>
object_allocator = nullptr);
~CoreWorkerMemoryStore(){};
~CoreWorkerMemoryStore() = default;

/// Put an object with specified ID into object store.
/// Put an object with specified ID into object store. If there are pending GetAsync
/// requests, the callbacks are posted onto the io_context.
///
/// \param[in] object The ray object.
/// \param[in] object_id Object ID specified by user.
Expand Down Expand Up @@ -193,6 +197,10 @@ class CoreWorkerMemoryStore {
void EraseObjectAndUpdateStats(const ObjectID &object_id)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Only created if the ctor is not given an io_context (i.e. it receives nullptr).
std::unique_ptr<InstrumentedIOContextWithThread> owned_io_context_with_thread_;
instrumented_io_context &io_context_;

/// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this
/// mandatory once Java is supported.
std::shared_ptr<ReferenceCounter> ref_counter_ = nullptr;
Expand Down
3 changes: 2 additions & 1 deletion src/ray/core_worker/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id,
rpc::Address owner_address;
if (reference_counter_->GetOwner(object_id, &owner_address) && !nested_refs.empty()) {
std::vector<ObjectID> nested_ids;
nested_ids.reserve(nested_refs.size());
for (const auto &nested_ref : nested_refs) {
nested_ids.emplace_back(ObjectRefToId(nested_ref));
}
Expand Down Expand Up @@ -792,7 +793,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
if (!HandleTaskReturn(object_id,
return_object,
NodeID::FromBinary(worker_addr.raylet_id()),
store_in_plasma_ids.count(object_id))) {
store_in_plasma_ids.contains(object_id))) {
if (first_execution) {
dynamic_returns_in_plasma.push_back(object_id);
}
Expand Down
58 changes: 50 additions & 8 deletions src/ray/core_worker/test/dependency_resolver_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class MockTaskFinisher : public TaskFinisherInterface {

class MockActorCreator : public ActorCreatorInterface {
public:
MockActorCreator() {}
MockActorCreator() = default;

Status RegisterActor(const TaskSpecification &task_spec) const override {
return Status::OK();
Expand Down Expand Up @@ -195,8 +195,12 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) {
actor_handle_id.Binary());

int num_resolved = 0;
std::promise<bool> dependencies_resolved;
actor_creator.actor_pending = true;
resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; });
resolver.ResolveDependencies(task, [&](const Status &) {
num_resolved++;
dependencies_resolved.set_value(true);
});
ASSERT_EQ(num_resolved, 0);
ASSERT_EQ(resolver.NumPendingTasks(), 1);

Expand All @@ -210,6 +214,8 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) {
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
auto data = RayObject(nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
ASSERT_TRUE(store->Put(data, obj));
// Wait for the async callback to be invoked.
ASSERT_TRUE(dependencies_resolved.get_future().get());
ASSERT_EQ(num_resolved, 1);

ASSERT_EQ(resolver.NumPendingTasks(), 0);
Expand All @@ -231,8 +237,12 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) {
actor_handle_id.Binary());

int num_resolved = 0;
std::promise<bool> dependencies_resolved;
actor_creator.actor_pending = true;
resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; });
resolver.ResolveDependencies(task, [&](const Status &) {
num_resolved++;
dependencies_resolved.set_value(true);
});
ASSERT_EQ(num_resolved, 0);
ASSERT_EQ(resolver.NumPendingTasks(), 1);

Expand All @@ -246,6 +256,9 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) {
for (const auto &cb : actor_creator.callbacks) {
cb(Status());
}
// Wait for the async callback to be invoked.
ASSERT_TRUE(dependencies_resolved.get_future().get());

ASSERT_EQ(num_resolved, 1);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
Expand All @@ -264,7 +277,12 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
TaskSpecification task;
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
std::promise<bool> dependencies_resolved;
resolver.ResolveDependencies(task, [&](Status) {
ok = true;
dependencies_resolved.set_value(true);
});
ASSERT_TRUE(dependencies_resolved.get_future().get());
ASSERT_TRUE(ok);
ASSERT_TRUE(task.ArgByRef(0));
// Checks that the object id is still a direct call id.
Expand All @@ -287,7 +305,12 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
std::promise<bool> dependencies_resolved;
resolver.ResolveDependencies(task, [&](Status) {
ok = true;
dependencies_resolved.set_value(true);
});
ASSERT_TRUE(dependencies_resolved.get_future().get());
// Tests that the task proto was rewritten to have inline argument values.
ASSERT_TRUE(ok);
ASSERT_FALSE(task.ArgByRef(0));
Expand All @@ -310,11 +333,17 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
std::promise<bool> dependencies_resolved;
resolver.ResolveDependencies(task, [&](Status) {
ok = true;
dependencies_resolved.set_value(true);
});
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store->Put(*data, obj1));
ASSERT_TRUE(store->Put(*data, obj2));

ASSERT_TRUE(dependencies_resolved.get_future().get());
// Tests that the task proto was rewritten to have inline argument values after
// resolution completes.
ASSERT_TRUE(ok);
Expand All @@ -340,11 +369,17 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
std::promise<bool> dependencies_resolved;
resolver.ResolveDependencies(task, [&](Status) {
ok = true;
dependencies_resolved.set_value(true);
});
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store->Put(*data, obj1));
ASSERT_TRUE(store->Put(*data, obj2));

ASSERT_TRUE(dependencies_resolved.get_future().get());
// Tests that the task proto was rewritten to have inline argument values after
// resolution completes.
ASSERT_TRUE(ok);
Expand Down Expand Up @@ -385,6 +420,8 @@ TEST(LocalDependencyResolverTest, TestCancelDependencyResolution) {
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}

// Even if dependencies are already local, the ResolveDependencies callbacks are still
// called asynchronously in the event loop as a different task.
TEST(LocalDependencyResolverTest, TestDependenciesAlreadyLocal) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
Expand All @@ -398,7 +435,12 @@ TEST(LocalDependencyResolverTest, TestDependenciesAlreadyLocal) {
TaskSpecification task;
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
std::promise<bool> dependencies_resolved;
resolver.ResolveDependencies(task, [&](Status) {
ok = true;
dependencies_resolved.set_value(true);
});
ASSERT_TRUE(dependencies_resolved.get_future().get());
ASSERT_TRUE(ok);
// Check for leaks.
ASSERT_EQ(resolver.NumPendingTasks(), 0);
Expand Down
12 changes: 11 additions & 1 deletion src/ray/core_worker/test/direct_actor_transport_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam<bool> {
return worker_client_;
})),
worker_client_(std::make_shared<MockWorkerClient>()),
store_(std::make_shared<CoreWorkerMemoryStore>()),
store_(std::make_shared<CoreWorkerMemoryStore>(&io_context)),
task_finisher_(std::make_shared<MockTaskFinisherInterface>()),
io_work(io_context),
reference_counter_(std::make_shared<MockReferenceCounter>()),
Expand Down Expand Up @@ -253,10 +253,16 @@ TEST_P(ActorTaskSubmitterTest, TestDependencies) {

// Put the dependencies in the store in the same order as task submission.
auto data = GenerateRandomObject();

// Each Put posts a callback onto io_context; run it explicitly with poll_one().
ASSERT_TRUE(store_->Put(*data, obj1));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 1);

ASSERT_TRUE(store_->Put(*data, obj2));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 2);

ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(0, 1));
}

Expand Down Expand Up @@ -296,19 +302,23 @@ TEST_P(ActorTaskSubmitterTest, TestOutOfOrderDependencies) {
auto data = GenerateRandomObject();
// task2 is submitted first as we allow out of order execution.
ASSERT_TRUE(store_->Put(*data, obj2));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 1);
ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1));
// then task1 is submitted
ASSERT_TRUE(store_->Put(*data, obj1));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 2);
ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1, 0));
} else {
// Put the dependencies in the store in the opposite order of task
// submission.
auto data = GenerateRandomObject();
ASSERT_TRUE(store_->Put(*data, obj2));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 0);
ASSERT_TRUE(store_->Put(*data, obj1));
ASSERT_EQ(io_context.poll_one(), 1);
ASSERT_EQ(worker_client_->callbacks.size(), 2);
ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(0, 1));
}
Expand Down
Loading