[core] add option for raylet to inform whether a task should be retried #31230

Merged · 3 commits · Jan 3, 2023
3 changes: 2 additions & 1 deletion src/mock/ray/core_worker/task_manager.h
@@ -39,7 +39,8 @@ class MockTaskFinisherInterface : public TaskFinisherInterface {
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info,
bool mark_task_object_failed),
bool mark_task_object_failed,
bool fail_immediately),
(override));
MOCK_METHOD(void,
OnTaskDependenciesInlined,
14 changes: 11 additions & 3 deletions src/ray/common/task/task_util.h
@@ -21,12 +21,20 @@

namespace ray {

/// Stores the task failure reason and when this entry was created.
/// Stores the task failure reason.
struct TaskFailureEntry {
/// The task failure details.
rpc::RayErrorInfo ray_error_info;

/// The creation time of this entry.
std::chrono::steady_clock::time_point creation_time;
TaskFailureEntry(const rpc::RayErrorInfo &ray_error_info)
: ray_error_info(ray_error_info), creation_time(std::chrono::steady_clock::now()) {}

/// Whether this task should be retried.
bool should_retry;
TaskFailureEntry(const rpc::RayErrorInfo &ray_error_info, bool should_retry)
: ray_error_info(ray_error_info),
creation_time(std::chrono::steady_clock::now()),
should_retry(should_retry) {}
};

/// Argument of a task.
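The new should_retry field lets a failure record carry the raylet's retry verdict alongside the error itself. Below is a minimal sketch of how the updated entry might be stored and queried; the map and helper functions are illustrative assumptions, not part of this change.

// Illustrative only: the map and helpers below are hypothetical consumers
// of TaskFailureEntry; only the struct itself comes from this PR.
std::unordered_map<TaskID, TaskFailureEntry> task_failures;

void RecordFailure(const TaskID &task_id,
                   const rpc::RayErrorInfo &error_info,
                   bool should_retry) {
  // creation_time is stamped inside the constructor via steady_clock::now().
  task_failures.emplace(task_id, TaskFailureEntry(error_info, should_retry));
}

bool ShouldRetry(const TaskID &task_id) {
  auto it = task_failures.find(task_id);
  // Retry by default when no failure has been recorded for the task.
  return it == task_failures.end() || it->second.should_retry;
}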
11 changes: 8 additions & 3 deletions src/ray/core_worker/task_manager.cc
@@ -540,13 +540,18 @@ bool TaskManager::FailOrRetryPendingTask(const TaskID &task_id,
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info,
bool mark_task_object_failed) {
bool mark_task_object_failed,
bool fail_immediately) {
// Note that this might be the __ray_terminate__ task, so we don't log
// loudly with ERROR here.
RAY_LOG(DEBUG) << "Task attempt " << task_id << " failed with error "
<< rpc::ErrorType_Name(error_type);
const bool will_retry = RetryTaskIfPossible(
task_id, /*task_failed_due_to_oom*/ error_type == rpc::ErrorType::OUT_OF_MEMORY);
bool will_retry = false;
if (!fail_immediately) {
will_retry = RetryTaskIfPossible(
task_id, /*task_failed_due_to_oom*/ error_type == rpc::ErrorType::OUT_OF_MEMORY);
}

if (!will_retry && mark_task_object_failed) {
FailPendingTask(task_id, error_type, status, ray_error_info);
}
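With this change the decision order is explicit: fail_immediately short-circuits the retry check entirely; otherwise RetryTaskIfPossible decides; and the return object is marked failed only when no retry will happen and the caller asked for it. A condensed, hypothetical restatement of that flow, as a standalone illustration rather than the actual implementation:

// Condensed restatement of the branch above; illustrative only.
bool WillRetry(bool fail_immediately,
               bool retry_possible,  // stands in for RetryTaskIfPossible(...)
               bool mark_task_object_failed) {
  const bool will_retry = !fail_immediately && retry_possible;
  if (!will_retry && mark_task_object_failed) {
    // In the real code, FailPendingTask(...) runs here, storing the error
    // in the task's return object.
  }
  return will_retry;
}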
8 changes: 6 additions & 2 deletions src/ray/core_worker/task_manager.h
@@ -49,7 +49,8 @@ class TaskFinisherInterface {
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info = nullptr,
bool mark_task_object_failed = true) = 0;
bool mark_task_object_failed = true,
bool fail_immediately = false) = 0;

virtual void MarkTaskWaitingForExecution(const TaskID &task_id,
const NodeID &node_id) = 0;
@@ -185,12 +186,15 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface
/// \param[in] mark_task_object_failed whether or not it marks the task
/// return object as failed. If this is set to false, then the caller is
/// responsible for later failing or completing the task.
/// \param[in] fail_immediately whether to fail the task immediately,
/// ignoring any retries that are still available.
/// \return Whether the task will be retried or not.
bool FailOrRetryPendingTask(const TaskID &task_id,
rpc::ErrorType error_type,
const Status *status = nullptr,
const rpc::RayErrorInfo *ray_error_info = nullptr,
bool mark_task_object_failed = true) override;
bool mark_task_object_failed = true,
bool fail_immediately = false) override;

/// A pending task failed. This will mark the task as failed.
/// This doesn't always mark the return object as failed
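Both new parameters are defaulted, so existing call sites compile unchanged and a caller must opt into fast failure explicitly. A usage sketch, where the manager instance, task_id, and error values are placeholders:

// Unchanged call site: the defaults give the old fail-or-retry behavior.
manager.FailOrRetryPendingTask(task_id, rpc::ErrorType::WORKER_DIED);

// Opting into fast failure, bypassing any remaining retries.
manager.FailOrRetryPendingTask(task_id,
                               rpc::ErrorType::OUT_OF_MEMORY,
                               /*status=*/nullptr,
                               /*ray_error_info=*/nullptr,
                               /*mark_task_object_failed=*/true,
                               /*fail_immediately=*/true);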
3 changes: 2 additions & 1 deletion src/ray/core_worker/test/dependency_resolver_test.cc
@@ -82,7 +82,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info = nullptr,
bool mark_task_object_failed = true) override {
bool mark_task_object_failed = true,
bool fail_immediately = false) override {
num_tasks_failed++;
return true;
}
4 changes: 2 additions & 2 deletions src/ray/core_worker/test/direct_actor_transport_mock_test.cc
@@ -97,7 +97,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) {
EXPECT_CALL(
*task_finisher,
FailOrRetryPendingTask(
task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _));
task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _, _));
register_cb(Status::IOError(""));
}

@@ -119,7 +119,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
actor_task_submitter->AddActorQueueIfNotExists(actor_id, -1);
ASSERT_TRUE(CheckSubmitTask(task_spec));
EXPECT_CALL(*task_finisher, FailOrRetryPendingTask(_, _, _, _, _)).Times(0);
EXPECT_CALL(*task_finisher, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0);
register_cb(Status::OK());
}

32 changes: 16 additions & 16 deletions src/ray/core_worker/test/direct_actor_transport_test.cc
@@ -165,7 +165,7 @@ TEST_P(DirectActorSubmitterTest, TestSubmitTask) {

EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _, _))
.Times(worker_client_->callbacks.size());
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(_, _, _, _, _)).Times(0);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0);
while (!worker_client_->callbacks.empty()) {
ASSERT_TRUE(worker_client_->ReplyPushTask());
}
@@ -314,18 +314,18 @@ TEST_P(DirectActorSubmitterTest, TestActorDead) {
ASSERT_EQ(worker_client_->callbacks.size(), 1);

// Simulate the actor dying. All in-flight tasks should get failed.
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task1.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task1.TaskId(), _, _, _, _, _))
.Times(1);
EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _, _)).Times(0);
while (!worker_client_->callbacks.empty()) {
ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError("")));
}

EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(_, _, _, _, _)).Times(0);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0);
const auto death_cause = CreateMockDeathCause();
submitter_.DisconnectActor(actor_id, 1, /*dead=*/false, death_cause);
// Actor marked as dead. All queued tasks should get failed.
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1);
submitter_.DisconnectActor(actor_id, 2, /*dead=*/true, death_cause);
}
@@ -352,9 +352,9 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartNoRetry) {
ASSERT_TRUE(CheckSubmitTask(task3));

EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _, _)).Times(1);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _, _))
.Times(1);
EXPECT_CALL(*task_finisher_, CompletePendingTask(task4.TaskId(), _, _, _)).Times(1);
// First task finishes. Second task fails.
@@ -407,10 +407,10 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartRetry) {
// All tasks will eventually finish.
EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _, _)).Times(4);
// Tasks 2 and 3 will be retried.
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1)
.WillRepeatedly(Return(true));
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _, _))
.Times(1)
.WillRepeatedly(Return(true));
// First task finishes. Second task fails.
@@ -467,7 +467,7 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) {
EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _, _)).Times(3);

// Tasks 2 will be retried
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1)
.WillRepeatedly(Return(true));
// First task finishes. Second task hang. Third task finishes.
@@ -553,7 +553,7 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartOutOfOrderGcs) {
// Tasks submitted when the actor is in RESTARTING state will fail immediately.
// This happens in an io_service.post. Search `SendPendingTasks_ForceFail` to locate
// the code.
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task.TaskId(), _, _, _, _, _))
.Times(1);
ASSERT_EQ(io_context.poll_one(), 1);

@@ -574,7 +574,7 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartOutOfOrderGcs) {
ASSERT_EQ(num_clients_connected_, 2);
// Submit a task.
task = CreateActorTaskHelper(actor_id, worker_id, 4);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task.TaskId(), _, _, _, _, _))
.Times(1);
ASSERT_FALSE(CheckSubmitTask(task));
}
@@ -605,20 +605,20 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartFailInflightTasks) {
ASSERT_TRUE(CheckSubmitTask(task3));
// Actor failed, but the task replies are delayed (or in some scenarios, lost).
// We should still be able to fail the inflight tasks.
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _, _))
.Times(1);
const auto death_cause = CreateMockDeathCause();
submitter_.DisconnectActor(actor_id, 1, /*dead=*/false, death_cause);

// The task replies are now received. Since the tasks are already failed, they will not
// be marked as failed or finished again.
EXPECT_CALL(*task_finisher_, CompletePendingTask(task2.TaskId(), _, _, _)).Times(0);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(0);
EXPECT_CALL(*task_finisher_, CompletePendingTask(task3.TaskId(), _, _, _)).Times(0);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task3.TaskId(), _, _, _, _, _))
.Times(0);
// Task 2 replied with OK.
ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK()));
@@ -652,7 +652,7 @@ TEST_P(DirectActorSubmitterTest, TestActorRestartFastFail) {
auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1);
ASSERT_TRUE(CheckSubmitTask(task2));
EXPECT_CALL(*task_finisher_, CompletePendingTask(task2.TaskId(), _, _, _)).Times(0);
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _))
EXPECT_CALL(*task_finisher_, FailOrRetryPendingTask(task2.TaskId(), _, _, _, _, _))
.Times(1);
ASSERT_EQ(io_context.poll_one(), 1);
}
7 changes: 4 additions & 3 deletions src/ray/core_worker/test/direct_task_transport_mock_test.cc
@@ -81,9 +81,10 @@ TEST_F(DirectTaskTransportTest, ActorCreationFail) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto task_spec = GetCreatingTaskSpec(actor_id);
EXPECT_CALL(*task_finisher, CompletePendingTask(_, _, _, _)).Times(0);
EXPECT_CALL(*task_finisher,
FailOrRetryPendingTask(
task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _, true));
EXPECT_CALL(
*task_finisher,
FailOrRetryPendingTask(
task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _, true, false));
rpc::ClientCallback<rpc::CreateActorReply> create_cb;
EXPECT_CALL(*actor_creator, AsyncCreateActor(task_spec, _))
.WillOnce(DoAll(SaveArg<1>(&create_cb), Return(Status::OK())));
3 changes: 2 additions & 1 deletion src/ray/core_worker/test/direct_task_transport_test.cc
@@ -122,7 +122,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info = nullptr,
bool mark_task_object_failed = true) override {
bool mark_task_object_failed = true,
bool fail_immediately = false) override {
num_tasks_failed++;
return true;
}
52 changes: 52 additions & 0 deletions src/ray/core_worker/test/task_manager_test.cc
@@ -453,6 +453,58 @@ TEST_F(TaskManagerTest, TestTaskNotRetriableOomFailsImmediatelyEvenWithOomRetryC
ASSERT_EQ(stored_error, rpc::ErrorType::OUT_OF_MEMORY);
}

TEST_F(TaskManagerTest, TestFailsImmediatelyOverridesRetry) {
RayConfig::instance().initialize(R"({"task_oom_retries": 1})");

{
ray::rpc::ErrorType error = rpc::ErrorType::OUT_OF_MEMORY;

rpc::Address caller_address;
auto spec = CreateTaskHelper(1, {});
manager_.AddPendingTask(caller_address, spec, "", /*max retries*/ 10);
auto return_id = spec.ReturnId(0);

manager_.FailOrRetryPendingTask(spec.TaskId(),
error,
/*status*/ nullptr,
/*error info*/ nullptr,
/*mark object failed*/ true,
/*fail immediately*/ true);

std::vector<std::shared_ptr<RayObject>> results;
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results));
ASSERT_EQ(results.size(), 1);
rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
}

{
ray::rpc::ErrorType error = rpc::ErrorType::WORKER_DIED;

rpc::Address caller_address;
auto spec = CreateTaskHelper(1, {});
manager_.AddPendingTask(caller_address, spec, "", /*max retries*/ 10);
auto return_id = spec.ReturnId(0);

manager_.FailOrRetryPendingTask(spec.TaskId(),
error,
/*status*/ nullptr,
/*error info*/ nullptr,
/*mark object failed*/ true,
/*fail immediately*/ true);

std::vector<std::shared_ptr<RayObject>> results;
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results));
ASSERT_EQ(results.size(), 1);
rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
}
}

// Test to make sure that the task spec and dependencies for an object are
// evicted when lineage pinning is disabled in the ReferenceCounter.
TEST_F(TaskManagerTest, TestLineageEvicted) {
3 changes: 2 additions & 1 deletion src/ray/core_worker/transport/direct_actor_task_submitter.cc
@@ -522,7 +522,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandlePushTaskReply(
error_type,
&status,
&error_info,
/*mark_task_object_failed*/ is_actor_dead);
/*mark_task_object_failed*/ is_actor_dead,
/*fail_immediately*/ false);

if (!is_actor_dead && !will_retry) {
// No retry == actor is dead.
5 changes: 4 additions & 1 deletion src/ray/core_worker/transport/direct_task_transport.cc
@@ -648,6 +648,7 @@ void CoreWorkerDirectTaskSubmitter::HandleGetTaskFailureCause(
const rpc::GetTaskFailureCauseReply &get_task_failure_cause_reply) {
rpc::ErrorType task_error_type = rpc::ErrorType::WORKER_DIED;
std::unique_ptr<rpc::RayErrorInfo> error_info;
bool fail_immediately = false;
if (get_task_failure_cause_reply_status.ok()) {
RAY_LOG(DEBUG) << "Task failure cause for task " << task_id << ": "
<< ray::gcs::RayErrorInfoToString(
@@ -679,7 +680,9 @@
task_id,
is_actor ? rpc::ErrorType::ACTOR_DIED : task_error_type,
&task_execution_status,
error_info.get()));
error_info.get(),
/*mark_task_object_failed*/ true,
fail_immediately));
}

Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
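The assignment that flips fail_immediately to true is elided from this hunk. Presumably it adopts the raylet's verdict from the new reply field when the RPC itself succeeds; the following is only a sketch of that assumed shape, not code shown in this diff:

// Assumption: the elided code reads the new reply field, along these lines.
if (get_task_failure_cause_reply_status.ok()) {
  fail_immediately = get_task_failure_cause_reply.fail_task_immediately();
}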
1 change: 1 addition & 0 deletions src/ray/protobuf/node_manager.proto
@@ -345,6 +345,7 @@ message GetTaskFailureCauseRequest {

message GetTaskFailureCauseReply {
optional RayErrorInfo failure_cause = 1;
bool fail_task_immediately = 2;
}

// Service for inter-node-manager communication.
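The raylet-side code that fills in this reply is outside this diff. A hypothetical sketch of how a handler could populate the new field from a stored TaskFailureEntry, using the standard generated protobuf accessors:

// Hypothetical raylet-side handler body; the entry lookup is illustrative.
void FillReply(const TaskFailureEntry &entry,
               rpc::GetTaskFailureCauseReply *reply) {
  reply->mutable_failure_cause()->CopyFrom(entry.ray_error_info);
  // The wire field carries the inverse of the stored retry verdict.
  reply->set_fail_task_immediately(!entry.should_retry);
}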