From 0131353d42f20f28034480d45918130525fb0377 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 26 Jun 2019 05:31:19 +0800 Subject: [PATCH] [gRPC] Migrate gcs data structures to protobuf (#5024) --- BUILD.bazel | 96 ++-- bazel/ray_deps_build_all.bzl | 4 + bazel/ray_deps_setup.bzl | 11 +- doc/source/conf.py | 15 +- java/BUILD.bazel | 51 +-- java/dependencies.bzl | 1 + ...modify_generated_java_flatbuffers_files.py | 20 +- java/runtime/pom.xml | 5 + .../java/org/ray/runtime/gcs/GcsClient.java | 69 +-- .../runtime/objectstore/ObjectStoreProxy.java | 12 +- python/ray/gcs_utils.py | 71 ++- python/ray/monitor.py | 33 +- python/ray/state.py | 230 ++++------ python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/test_basic.py | 14 +- python/ray/tests/test_failure.py | 5 +- python/ray/utils.py | 8 +- python/ray/worker.py | 40 +- python/setup.py | 1 + src/ray/gcs/client.cc | 4 - src/ray/gcs/client.h | 6 - src/ray/gcs/client_test.cc | 353 +++++++-------- src/ray/gcs/format/gcs.fbs | 281 +----------- src/ray/gcs/redis_context.h | 15 +- src/ray/gcs/redis_module/ray_redis_module.cc | 209 ++++----- src/ray/gcs/tables.cc | 417 ++++++++---------- src/ray/gcs/tables.h | 136 +++--- src/ray/object_manager/object_directory.cc | 34 +- src/ray/object_manager/object_manager.cc | 49 +- src/ray/object_manager/object_manager.h | 4 +- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/protobuf/gcs.proto | 280 ++++++++++++ src/ray/raylet/actor_registration.cc | 51 +-- src/ray/raylet/actor_registration.h | 24 +- src/ray/raylet/lineage_cache.cc | 37 +- src/ray/raylet/lineage_cache.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 237 +++++----- src/ray/raylet/node_manager.h | 26 +- src/ray/raylet/raylet.cc | 24 +- src/ray/raylet/raylet.h | 2 + src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 + src/ray/raylet/reconstruction_policy_test.cc | 42 +- src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 + .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/worker_pool.cc | 4 +- src/ray/rpc/util.h | 13 + 52 files changed, 1465 insertions(+), 1642 deletions(-) create mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index da36eec0cf577..bc9e6bcd8006f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,22 +1,55 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# Node manager gRPC lib. -grpc_proto_library( - name = "node_manager_grpc_lib", +# === Begin of protobuf definitions === + +proto_library( + name = "gcs_proto", + srcs = ["src/ray/protobuf/gcs.proto"], + visibility = ["//java:__subpackages__"], +) + +cc_proto_library( + name = "gcs_cc_proto", + deps = [":gcs_proto"], +) + +python_proto_compile( + name = "gcs_py_proto", + deps = [":gcs_proto"], +) + +proto_library( + name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], ) +cc_proto_library( + name = "node_manager_cc_proto", + deps = ["node_manager_proto"], +) + +# === End of protobuf definitions === + +# Node manager gRPC lib. +cc_grpc_library( + name = "node_manager_cc_grpc", + srcs = [":node_manager_proto"], + grpc_only = True, + deps = [":node_manager_cc_proto"], +) + # Node manager server and client. cc_library( - name = "node_manager_rpc_lib", + name = "node_manager_rpc", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -25,7 +58,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_grpc_lib", + ":node_manager_cc_grpc", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -114,7 +147,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc_lib", + ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", @@ -422,9 +455,11 @@ cc_library( "src/ray/gcs/format", ], deps = [ + ":gcs_cc_proto", ":gcs_fbs", ":hiredis", ":node_manager_fbs", + ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -555,46 +590,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -flatbuffer_py_library( - name = "python_gcs_fbs", - srcs = [ - ":gcs_fbs_file", - ], - outs = [ - "ActorCheckpointIdData.py", - "ActorState.py", - "ActorTableData.py", - "Arg.py", - "ClassTableData.py", - "ClientTableData.py", - "ConfigTableData.py", - "CustomSerializerData.py", - "DriverTableData.py", - "EntryType.py", - "ErrorTableData.py", - "ErrorType.py", - "FunctionTableData.py", - "GcsEntry.py", - "HeartbeatBatchTableData.py", - "HeartbeatTableData.py", - "Language.py", - "ObjectTableData.py", - "ProfileEvent.py", - "ProfileTableData.py", - "RayResource.py", - "ResourcePair.py", - "SchedulingState.py", - "TablePrefix.py", - "TablePubsub.py", - "TaskInfo.py", - "TaskLeaseData.py", - "TaskReconstructionData.py", - "TaskTableData.py", - "TaskTableTestAndUpdate.py", - ], - out_prefix = "python/ray/core/generated/", -) - flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -679,6 +674,7 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ + ":gcs_cc_proto", ":ray_common", ], ) @@ -688,7 +684,7 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:python_gcs_fbs", + "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -710,11 +706,13 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && - for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && + for f in $(locations //:gcs_py_proto); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/; + done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 3e1e1838a59a3..eda88bece7d22 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,6 +4,8 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") +load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -13,4 +15,6 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() + java_proto_compile() + python_proto_compile() diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index e6dc215856996..aa322654cf9ff 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,7 +105,14 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", ], - strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", + ) + + http_archive( + name = "build_stack_rules_proto", + urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], + sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", + strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", ) diff --git a/doc/source/conf.py b/doc/source/conf.py index 98fb3e0d02dd3..5cf6b01217f97 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,20 +23,7 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.EntryType", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ErrorType", - "ray.core.generated.GcsEntry", - "ray.core.generated.HeartbeatBatchTableData", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.Language", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", + "ray.core.generated.gcs_pb2", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 80ccabccfc121..4960434af1804 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,5 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -50,6 +51,7 @@ define_java_module( name = "runtime", additional_srcs = [ ":generate_java_gcs_fbs", + ":gcs_java_proto", ], additional_resources = [ ":java_native_deps", @@ -68,6 +70,7 @@ define_java_module( "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -148,38 +151,16 @@ java_binary( ], ) +java_proto_compile( + name = "gcs_java_proto", + deps = ["@//:gcs_proto"], +) + flatbuffers_generated_files = [ - "ActorCheckpointData.java", - "ActorCheckpointIdData.java", - "ActorState.java", - "ActorTableData.java", "Arg.java", - "ClassTableData.java", - "ClientTableData.java", - "ConfigTableData.java", - "CustomSerializerData.java", - "DriverTableData.java", - "EntryType.java", - "ErrorTableData.java", - "ErrorType.java", - "FunctionTableData.java", - "GcsEntry.java", - "HeartbeatBatchTableData.java", - "HeartbeatTableData.java", "Language.java", - "ObjectTableData.java", - "ProfileEvent.java", - "ProfileTableData.java", - "RayResource.java", - "ResourcePair.java", - "SchedulingState.java", - "TablePrefix.java", - "TablePubsub.java", "TaskInfo.java", - "TaskLeaseData.java", - "TaskReconstructionData.java", - "TaskTableData.java", - "TaskTableTestAndUpdate.java", + "ResourcePair.java", ] flatbuffer_java_library( @@ -198,7 +179,7 @@ genrule( cmd = """ for f in $(locations //java:java_gcs_fbs); do chmod +w $$f - cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated done python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. """, @@ -221,8 +202,10 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":java_native_deps", + ":gcs_java_proto", ":generate_java_gcs_fbs", + ":java_native_deps", + ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -237,10 +220,15 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy flatbuffers-generated files + # Copy protobuf-generated files. GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files for f in $(locations //java:generate_java_gcs_fbs); do cp $$f $$GENERATED_DIR done @@ -250,6 +238,7 @@ genrule( echo $$(date) > $@ """, local = 1, + tags = ["no-cache"], ) genrule( diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 7c716166d3995..ef667137562b7 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,6 +6,7 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", + "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index c1b723f25f8d4..5bf62e56d7e47 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,7 +4,6 @@ import os import sys - """ This script is used for modifying the generated java flatbuffer files for the reason: The package declaration in Java is different @@ -21,19 +20,18 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_new_line(file, line_num, text): +def add_package(file): with open(file, "r") as file_handler: lines = file_handler.readlines() - if (line_num <= 0) or (line_num > len(lines) + 1): - return False - lines.insert(line_num - 1, text + os.linesep) + if "FlatBuffers" not in lines[0]: + return + + lines.insert(1, PACKAGE_DECLARATION + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) - return True - def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -41,15 +39,11 @@ def add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - success = add_new_line(full_name, 2, PACKAGE_DECLARATION) - if not success: - raise RuntimeError("Failed to add package declarations, " - "file name is %s" % full_name) + add_package(full_name) if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, - "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c75e2eeef13f1..e13dd95f927f3 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,6 +41,11 @@ guava 27.0.1-jre + + com.google.protobuf + protobuf-java + 3.8.0 + com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 431b48ded58c6..17c248ed0a57f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; +import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.ActorCheckpointIdData; -import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.EntryType; -import org.ray.runtime.generated.TablePrefix; +import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; +import org.ray.runtime.generated.Gcs.ClientTableData; +import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; +import org.ray.runtime.generated.Gcs.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.name(TablePrefix.CLIENT); + final String prefix = TablePrefix.CLIENT.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,36 +63,42 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); - final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); + ClientTableData data = null; + try { + data = ClientTableData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + final UniqueId clientId = UniqueId + .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); - if (data.entryType() == EntryType.INSERTION) { + if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. Preconditions.checkState( - data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.nodeManagerAddress(), true, resources); + clientId, data.getNodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE) { + } else if (data.getEntryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); } } else { // Code path of node deletion. - Preconditions.checkState(data.entryType() == EntryType.DELETION); + Preconditions.checkState(data.getEntryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -107,7 +113,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); + TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. */ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = - ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( - data.checkpointIdsAsByteBuffer()); + ActorCheckpointIdData data = null; + try { + data = ActorCheckpointIdData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; + for (int i = 0; i < checkpointIds.length; i++) { + checkpointIds[i] = UniqueId + .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); + } for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index f9e310249a352..1a7e4701c22b3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.generated.Gcs.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) - .getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) - .getBytes(); + private static final byte[] WORKER_EXCEPTION_META = String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index cadd197ec73f0..ba72e96f41db1 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,38 +2,39 @@ from __future__ import division from __future__ import print_function -import flatbuffers -import ray.core.generated.ErrorTableData - -from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData -from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.DriverTableData import DriverTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsEntry import GcsEntry -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData -from ray.core.generated.Language import Language -from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.TablePrefix import TablePrefix -from ray.core.generated.TablePubsub import TablePubsub - from ray.core.generated.ray.protocol.Task import Task +from ray.core.generated.gcs_pb2 import ( + ActorCheckpointIdData, + ClientTableData, + DriverTableData, + ErrorTableData, + ErrorType, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, + TaskTableData, +) + __all__ = [ "ActorCheckpointIdData", "ClientTableData", "DriverTableData", "ErrorTableData", + "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", - "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", "TablePubsub", "Task", + "TaskTableData", "construct_error_message", ] @@ -42,13 +43,16 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# These prefixes must be kept up-to-date with the TablePrefix enum in +# gcs.proto. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. """ - builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id.binary()) - error_type_offset = builder.CreateString(error_type) - message_offset = builder.CreateString(message) - - ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) - ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( - builder, driver_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddType( - builder, error_type_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( - builder, message_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( - builder, timestamp) - error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( - builder) - builder.Finish(error_data_offset) - - return bytes(builder.Output()) + data = ErrorTableData() + data.driver_id = driver_id.binary() + data.type = error_type + data.error_message = message + data.timestamp = timestamp + return data.SerializeToString() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index c9e0424b3eb85..35597ef231e38 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,28 +101,26 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - heartbeat_data = gcs_entries.Entries(0) + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] - message = (ray.gcs_utils.HeartbeatBatchTableData. - GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + message = ray.gcs_utils.HeartbeatBatchTableData.FromString( + heartbeat_data) - for j in range(message.BatchLength()): - heartbeat_message = message.Batch(j) - - num_resources = heartbeat_message.ResourcesTotalLabelLength() + for heartbeat_message in message.batch: + num_resources = len(heartbeat_message.resources_available_label) static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.ResourcesAvailableLabel(i) - static = heartbeat_message.ResourcesTotalLabel(i) + dyn = heartbeat_message.resources_available_label[i] + static = heartbeat_message.resources_total_label[i] dynamic_resources[dyn] = ( - heartbeat_message.ResourcesAvailableCapacity(i)) + heartbeat_message.resources_available_capacity[i]) static_resources[static] = ( - heartbeat_message.ResourcesTotalCapacity(i)) + heartbeat_message.resources_total_capacity[i]) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. """ - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - driver_data = gcs_entries.Entries(0) - message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( - driver_data, 0) - driver_id = message.DriverId() + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + driver_data = gcs_entries.entries[0] + message = ray.gcs_utils.DriverTableData.FromString(driver_data) + driver_id = message.driver_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/state.py b/python/ray/state.py index 14ba49987ec4f..35f97cd65f5e4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,11 @@ import ray from ray.function_manager import FunctionDescriptor -import ray.gcs_utils -from ray.ray_constants import ID_SIZE -from ray import services -from ray.core.generated.EntryType import EntryType +from ray import ( + gcs_utils, + services, +) from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +31,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.CLIENT, - "", NIL_CLIENT_ID) + message = redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", + NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,36 +41,31 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) + for entry in gcs_entry.entries: + client = gcs_utils.ClientTableData.FromString(entry) resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) + client.resources_total_label[i]: client.resources_total_capacity[i] + for i in range(len(client.resources_total_label)) } - client_id = ray.utils.binary_to_hex(client.ClientId()) + client_id = ray.utils.binary_to_hex(client.client_id) - if client.EntryType() == EntryType.INSERTION: + if client.entry_type == gcs_utils.ClientTableData.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.EntryType(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), + "EntryType": client.entry_type, + "NodeManagerAddress": client.node_manager_address, + "NodeManagerPort": client.node_manager_port, + "ObjectManagerPort": client.object_manager_port, + "ObjectStoreSocketName": client.object_store_socket_name, + "RayletSocketName": client.raylet_socket_name, "Resources": resources } @@ -79,22 +74,23 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( - "Unexpected updation of deleted client.") + is_deletion = (node_info[client_id]["EntryType"] != + gcs_utils.ClientTableData.DELETION) + assert is_deletion, "Unexpected updation of deleted client." res_map = node_info[client_id]["Resources"] - if client.EntryType() == EntryType.RES_CREATEUPDATE: + if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.EntryType() == EntryType.RES_DELETE: + elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.EntryType() == EntryType.DELETION: + elif client.entry_type == gcs_utils.ClientTableData.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.EntryType())) + client.entry_type)) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.EntryType() + node_info[client_id]["EntryType"] = client.entry_type # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -244,20 +240,19 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, "", - object_id.binary()) + gcs_utils.TablePrefix.Value("OBJECT"), + "", object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) - assert gcs_entry.EntriesLength() > 0 + assert len(gcs_entry.entries) > 0 - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(0), 0) + entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), + "DataSize": entry.object_size, + "Manager": entry.manager, } return object_info @@ -278,10 +273,9 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + - "*") + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -301,17 +295,18 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. """ assert isinstance(task_id, ray.TaskID) - message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - "", task_id.binary()) + message = self._execute_command( + task_id, "RAY.TABLE_LOOKUP", + gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - - assert gcs_entries.EntriesLength() == 1 + gcs_entries = gcs_utils.GcsEntry.FromString(message) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) + assert len(gcs_entries.entries) == 1 + task_table_data = gcs_utils.TaskTableData.FromString( + gcs_entries.entries[0]) + task_table_message = gcs_utils.Task.GetRootAsTask( + task_table_data.task, 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -368,9 +363,9 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] @@ -380,27 +375,6 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results - def function_table(self, function_id=None): - """Fetch and parse the function table. - - Returns: - A dictionary that maps function IDs to information about the - function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys( - ray.gcs_utils.FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -423,37 +397,32 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.PROFILE, "", - batch_id.binary()) + gcs_utils.TablePrefix.Value("PROFILE"), + "", batch_id.binary()) if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) profile_events = [] - for i in range(gcs_entries.EntriesLength()): - profile_table_message = ( - ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( - gcs_entries.Entries(i), 0)) - - component_type = decode(profile_table_message.ComponentType()) - component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode( - profile_table_message.NodeIpAddress(), allow_none=True) + for entry in gcs_entries.entries: + profile_table_message = gcs_utils.ProfileTableData.FromString( + entry) - for j in range(profile_table_message.ProfileEventsLength()): - profile_event_message = profile_table_message.ProfileEvents(j) + component_type = profile_table_message.component_type + component_id = binary_to_hex(profile_table_message.component_id) + node_ip_address = profile_table_message.node_ip_address + for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": decode(profile_event_message.EventType()), + "event_type": profile_event_message.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.StartTime(), - "end_time": profile_event_message.EndTime(), - "extra_data": json.loads( - decode(profile_event_message.ExtraData())), + "start_time": profile_event_message.start_time, + "end_time": profile_event_message.end_time, + "extra_data": json.loads(profile_event_message.extra_data), } profile_events.append(profile_event) @@ -462,10 +431,10 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys( - ray.gcs_utils.TablePrefix_PROFILE_string + "*") + profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + + "*") batch_identifiers_binary = [ - key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -766,7 +735,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. - if client["EntryType"] != EntryType.DELETION: + if client["EntryType"] != gcs_utils.ClientTableData.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -776,7 +745,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != EntryType.DELETION) + if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) } def available_resources(self): @@ -800,7 +769,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -809,24 +778,23 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. - GetRootAsHeartbeatTableData(heartbeat_data, 0)) + gcs_entries = gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] + message = gcs_utils.HeartbeatTableData.FromString( + heartbeat_data) # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() + num_resources = len(message.resources_available_label) dynamic_resources = {} for i in range(num_resources): - resource_id = decode(message.ResourcesAvailableLabel(i)) + resource_id = message.resources_available_label[i] dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) + message.resources_available_capacity[i]) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) + client_id = ray.utils.binary_to_hex(message.client_id) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -860,23 +828,22 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", driver_id.binary()) # If there are no errors, return early. if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) error_messages = [] - for i in range(gcs_entries.EntriesLength()): - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entries.Entries(i), 0) - assert driver_id.binary() == error_data.DriverId() + for entry in gcs_entries.entries: + error_data = gcs_utils.ErrorTableData.FromString(entry) + assert driver_id.binary() == error_data.driver_id error_message = { - "type": decode(error_data.Type()), - "message": decode(error_data.ErrorMessage()), - "timestamp": error_data.Timestamp(), + "type": error_data.type, + "message": error_data.error_message, + "timestamp": error_data.timestamp, } error_messages.append(error_message) return error_messages @@ -899,9 +866,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") + gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -923,30 +890,23 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), "", actor_id.binary(), ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - entry = ( - ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( - gcs_entry.Entries(0), 0)) - checkpoint_ids_str = entry.CheckpointIds() - num_checkpoints = len(checkpoint_ids_str) // ID_SIZE - assert len(checkpoint_ids_str) % ID_SIZE == 0 + gcs_entry = gcs_utils.GcsEntry.FromString(message) + entry = gcs_utils.ActorCheckpointIdData.FromString( + gcs_entry.entries[0]) checkpoint_ids = [ - ray.ActorCheckpointID( - checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) - for i in range(num_checkpoints) + ray.ActorCheckpointID(checkpoint_id) + for checkpoint_id in entry.checkpoint_ids ] return { - "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": [ - entry.Timestamps(i) for i in range(num_checkpoints) - ], + "Timestamps": list(entry.timestamps), } diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 703c3a1420ed9..76dfd3000b860 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.core.generated.EntryType import EntryType +from ray.gcs_utils import ClientTableData logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == EntryType.INSERTION + if client["EntryType"] == ClientTableData.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7f1f78d1b5c41..6b4bd754cd4de 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 51b906695c2d3..a560e461f7a21 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), + ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, + malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/utils.py b/python/ray/utils.py index 7b87486e325ee..0db48e41d025c 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, - driver_id.binary(), error_data) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 7505120574a62..710f0db43c6b1 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.core.generated.ErrorType import ErrorType +from ray.gcs_utils import ErrorType from ray.exceptions import ( RayActorError, RayError, @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.WORKER_DIED: + if error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() - elif error_type == ErrorType.ACTOR_DIED: + elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() - elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - msg["data"], 0) - assert gcs_entry.EntriesLength() == 1 - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entry.Entries(0), 0) - driver_id = error_data.DriverId() + gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) + assert len(gcs_entry.entries) == 1 + error_data = ray.gcs_utils.ErrorTableData.FromString( + gcs_entry.entries[0]) + driver_id = error_data.driver_id if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = ray.utils.decode(error_data.ErrorMessage()) - if (ray.utils.decode( - error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + error_message = error_data.error_message + if (error_data.type == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1878,14 +1876,16 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) + task_table_data = ray.gcs_utils.TaskTableData() + task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. - ray.state.state._execute_command(driver_task.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command( + driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), + ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), + driver_task.task_id().binary(), + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/python/setup.py b/python/setup.py index db8676042de93..e7cf14737ee23 100644 --- a/python/setup.py +++ b/python/setup.py @@ -150,6 +150,7 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", + "protobuf", ] setup( diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c9b1e138575d6..6de29bb52764f 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,10 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } -FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } - -ClassTable &AsyncGcsClient::class_table() { return *class_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index c9f5b4bca6249..5e70025b39a06 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,11 +44,7 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. Status Attach(boost::asio::io_service &io_service); - inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver - inline ClassTable &class_table(); - inline CustomSerializerTable &custom_serializer_table(); - inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -81,8 +77,6 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: - std::unique_ptr function_table_; - std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c7dc02e50651e..55115b1e20675 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->task_specification = "123"; + auto data = std::make_shared(); + data->set_task("123"); // Check that we added the correct task. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { RAY_CHECK(false); }; + const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); data_log.push_back(data); } // Check that we added the correct task. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id); + appended_managers.push_back(entry.node_manager_id()); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { ASSERT_TRUE(false); }; + const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->node_manager_id = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_node_manager_id(ObjectID::FromRandom().Hex()); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::FromRandom().Hex(); + auto task_data = std::make_shared(); + task_data->set_task(ObjectID::FromRandom().Hex()); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_manager(ObjectID::FromRandom().Hex()); object_vector.push_back(data); } }; @@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } -// Task table callbacks. -void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); -} - -void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data, bool do_stop) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); - if (do_stop) { - test->Stop(); - } -} -void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/false); -} -void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/true); -} - -void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); -} - -void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); - test->Stop(); -} - -void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - test->Stop(); -} - void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. @@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->driver_id = driver_ids1[0]; + auto data1 = std::make_shared(); + data1->set_driver_id(driver_ids1[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->driver_id = driver_ids2[0]; + auto data2 = std::make_shared(); + data2->set_driver_id(driver_ids2[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->manager = managers1[0]; + auto data1 = std::make_shared(); + data1->set_manager(managers1[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->manager = managers2[0]; + auto data2 = std::make_shared(); + data2->set_manager(managers2[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->task_specification = task_specs[0]; + auto data = std::make_shared(); + data->set_task(task_specs[0]); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task_specification, task_specs.front()); + ASSERT_EQ(data.task(), task_specs.front()); } else { - ASSERT_EQ(data.task_specification, task_specs.back()); + ASSERT_EQ(data.task(), task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->driver_id = driver_ids[0]; + auto data = std::make_shared(); + data->set_driver_id(driver_ids[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->driver_id = remaining_driver_id; + auto data = std::make_shared(); + data->set_driver_id(remaining_driver_id); RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->manager = managers[0]; + auto data = std::make_shared(); + data->set_manager(managers[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager, managers[0]); + ASSERT_EQ(data[0].manager(), managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager); + data_managers_set.insert(entry.manager()); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableDataT &data, bool is_insertion) { + const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); - ClientTableDataT cached_client; + ClientTableData cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); + ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + const ClientTableData &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); }); test->Start(); @@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. - auto cpu_data = std::make_shared(); - cpu_data->resource_name = "CPU"; - cpu_data->resource_capacity = 100; - auto gpu_data = std::make_shared(); - gpu_data->resource_name = "GPU"; - gpu_data->resource_capacity = 2; + auto cpu_data = std::make_shared(); + cpu_data->set_resource_name("CPU"); + cpu_data->set_resource_capacity(100); + auto gpu_data = std::make_shared(); + gpu_data->set_resource_name("GPU"); + gpu_data->set_resource_capacity(2); DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->resource_name = "CPU"; - data_cpu->resource_capacity = 50; - auto data_gpu = std::make_shared(); - data_gpu->resource_name = "GPU"; - data_gpu->resource_capacity = 10; - auto data_custom = std::make_shared(); - data_custom->resource_name = "CUSTOM"; - data_custom->resource_capacity = 2; + auto data_cpu = std::make_shared(); + data_cpu->set_resource_name("CPU"); + data_cpu->set_resource_capacity(50); + auto data_gpu = std::make_shared(); + data_gpu->set_resource_name("GPU"); + data_gpu->set_resource_capacity(10); + auto data_custom = std::make_shared(); + data_custom->set_resource_name("CUSTOM"); + data_custom->set_resource_capacity(2); DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->resource_capacity = 50; + data_map2["CPU"]->set_resource_capacity(50); // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name, data.second->resource_name); - ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); + ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 90476da734257..c06c79a029288 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,52 +1,9 @@ -enum Language:int { - PYTHON = 0, - CPP = 1, - JAVA = 2 -} - -// These indexes are mapped to strings in ray_redis_module.cc. -enum TablePrefix:int { - UNUSED = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - FUNCTION, - TASK_RECONSTRUCTION, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - DRIVER, - PROFILE, - TASK_LEASE, - ACTOR_CHECKPOINT, - ACTOR_CHECKPOINT_ID, - NODE_RESOURCE, -} +// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). -// The channel that Add operations to the Table should be published on, if any. -enum TablePubsub:int { - NO_PUBLISH = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - TASK_LEASE, - DRIVER, - NODE_RESOURCE, -} - -// Enum for the entry type in the ClientTable -enum EntryType:int { - INSERTION = 0, - DELETION, - RES_CREATEUPDATE, - RES_DELETE, +enum Language:int { + PYTHON=0, + JAVA=1, + CPP=2, } table Arg { @@ -120,118 +77,6 @@ table ResourcePair { value: double; } -enum GcsChangeMode:int { - APPEND_OR_ADD = 0, - REMOVE, -} - -table GcsEntry { - change_mode: GcsChangeMode; - id: string; - entries: [string]; -} - -table FunctionTableData { - language: Language; - name: string; - data: string; -} - -table ObjectTableData { - // The size of the object. - object_size: long; - // The node manager ID that this object appeared on or was evicted by. - manager: string; -} - -table TaskReconstructionData { - // The number of times this task has been reconstructed so far. - num_reconstructions: int; - // The node manager that is trying to reconstruct the task. - node_manager_id: string; -} - -enum SchedulingState:int { - NONE = 0, - WAITING = 1, - SCHEDULED = 2, - QUEUED = 4, - RUNNING = 8, - DONE = 16, - LOST = 32, - RECONSTRUCTING = 64 -} - -table TaskTableData { - // The state of the task. - scheduling_state: SchedulingState; - // A raylet ID. - raylet_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // The number of times the task was spilled back by raylets. - spillback_count: long; - // A string of bytes representing the task specification. - task_info: string; - // TODO(pcm): This is at the moment duplicated in task_info, remove that one - updated: bool; -} - -table TaskTableTestAndUpdate { - test_raylet_id: string; - test_state_bitmask: SchedulingState; - update_state: SchedulingState; -} - -table ClassTableData { -} - -enum ActorState:int { - // Actor is alive. - ALIVE = 0, - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1, - // Actor is already dead and won't be reconstructed. - DEAD = 2 -} - -table ActorTableData { - // The ID of the actor that was created. - actor_id: string; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - actor_creation_dummy_object_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the node manager that created the actor. - node_manager_id: string; - // Current state of this actor. - state: ActorState; - // Max number of times this actor should be reconstructed. - max_reconstructions: int; - // Remaining number of reconstructions. - remaining_reconstructions: int; -} - -table ErrorTableData { - // The ID of the driver that the error is for. - driver_id: string; - // The type of the error. - type: string; - // The error message. - error_message: string; - // The timestamp of the error message. - timestamp: double; -} - -table CustomSerializerData { -} - -table ConfigTableData { -} - table ProfileEvent { // The type of the event. event_type: string; @@ -258,119 +103,3 @@ table ProfileTableData { // we don't want each event to require a GCS command. profile_events: [ProfileEvent]; } - -table RayResource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. - resource_capacity: double; -} - -table ClientTableData { - // The client ID of the client that the message is about. - client_id: string; - // The IP address of the client's node manager. - node_manager_address: string; - // The IPC socket name of the client's raylet. - raylet_socket_name: string; - // The IPC socket name of the client's plasma store. - object_store_socket_name: string; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - node_manager_port: int; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - object_manager_port: int; - // Enum to store the entry type in the log - entry_type: EntryType = INSERTION; - resources_total_label: [string]; - resources_total_capacity: [double]; -} - -table HeartbeatTableData { - // Node manager client id - client_id: string; - // Resource capacity currently available on this node manager. - resources_available_label: [string]; - resources_available_capacity: [double]; - // Total resource capacity configured for this node manager. - resources_total_label: [string]; - resources_total_capacity: [double]; - // Aggregate outstanding resource load on this node manager. - resource_load_label: [string]; - resource_load_capacity: [double]; -} - -table HeartbeatBatchTableData { - batch: [HeartbeatTableData]; -} - -// Data for a lease on task execution. -table TaskLeaseData { - // Node manager client ID. - node_manager_id: string; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - acquired_at: long; - // The period that the lease is active for. - timeout: long; -} - -table DriverTableData { - // The driver ID. - driver_id: string; - // Whether it's dead. - is_dead: bool; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -table ActorCheckpointData { - // ID of this actor. - actor_id: string; - // The dummy object ID of actor's most recently executed task. - execution_dependency: string; - // A list of IDs of this actor's handles. - handle_ids: [string]; - // The task counters of the above handles. - task_counters: [long]; - // The frontier dependencies of the above handles. - frontier_dependencies: [string]; - // A list of unreleased dummy objects from this actor. - unreleased_dummy_objects: [string]; - // The numbers of dependencies for the above unreleased dummy objects. - num_dummy_object_dependencies: [int]; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -table ActorCheckpointIdData { - // ID of this actor. - actor_id: string; - // IDs of this actor's available checkpoints. - // Note, this is a long string that concatenates all the IDs. - checkpoint_ids: string; - // A list of the timestamps for each of the above `checkpoint_ids`. - timestamps: [long]; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. -enum ErrorType:int { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 1, - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 2, - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 3, -} diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index fc42e5cd98c20..093aab2455d99 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -25,6 +25,9 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; + /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -126,8 +129,8 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, - int64_t length, const TablePrefix prefix, + Status RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -157,9 +160,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e291b7ffdb322..c3a82c320d06f 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,11 +5,16 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; +using ray::rpc::GcsChangeMode; +using ray::rpc::GcsEntry; +using ray::rpc::TablePrefix; +using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. Specifically, via @@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long > static_cast(TablePubsub::MAX) || - pubsub_channel_long < static_cast(TablePubsub::MIN)) { + if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || + pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long > static_cast(TablePrefix::MAX) || - table_prefix_long < static_cast(TablePrefix::MIN)) { + if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || + table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); + return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = - OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); + *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, + mutated_key_str); return Status::OK(); } @@ -165,18 +170,24 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// This is a helper method to convert a redis module string to a flatbuffer -/// string. +/// A helper function that creates `GcsEntry` protobuf object. /// -/// \param fbb The flatbuffer builder. -/// \param redis_string The redis string. -/// \return The flatbuffer string. -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); +/// \param[in] id Id of the entry. +/// \param[in] change_mode Change mode of the entry. +/// \param[in] entries Vector of entries. +/// \param[out] result The created `GcsEntry` object. +inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, + const std::vector &entries, + GcsEntry *result) { + const char *data; + size_t size; + data = RedisModule_StringPtrLen(id, &size); + result->set_id(data, size); + result->set_change_mode(change_mode); + for (const auto &entry : entries) { + data = RedisModule_StringPtrLen(entry, &size); + result->add_entries(data, size); + } } /// Helper method to publish formatted data to target channel. @@ -234,13 +245,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - auto data_buffer = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + CreateGcsEntry(id, change_mode, {data}, &gcs_entry); + std::string str = gcs_entry.SerializeAsString(); + auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -570,19 +578,20 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); - auto data_vec = flatbuffers::GetRoot(update_data_buf); - *change_mode = data_vec->change_mode(); + GcsEntry gcs_entry; + gcs_entry.ParseFromArray(update_data_buf, update_data_len); + *change_mode = gcs_entry.change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. - size_t total_size = data_vec->entries()->size(); + size_t total_size = gcs_entry.entries_size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); - RedisModuleString *entry_value = - RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), - data_vec->entries()->Get(i + 1)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + RedisModuleString *entry_value = RedisModule_CreateString( + ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); // Returning 0 if key exists(still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -590,27 +599,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries. - size_t total_size = data_vec->entries()->size(); - flatbuffers::FlatBufferBuilder fbb; - std::vector> data; + GcsEntry updated; + updated.set_id(gcs_entry.id()); + updated.set_change_mode(gcs_entry.change_mode()); + + size_t total_size = gcs_entry.entries_size(); for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), - data_vec->entries()->Get(i)->size())); + updated.add_entries(gcs_entry.entries(i)); } } - auto message = - CreateGcsEntry(fbb, data_vec->change_mode(), - fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), - fbb.CreateVector(data)); - fbb.Finish(message); - *changed_data = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + + // Serialize updated data. + std::string str = updated.SerializeAsString(); + *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -631,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry flatbugger data used to update this hash table. +/// \param data The GcsEntry protobuf data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -648,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create and finish a GcsEntry, based on the +/// A helper function to create a GcsEntry protobuf, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -658,21 +665,18 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsEntry. -Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { +/// \param[out] gcs_entry The created GcsEntry. +Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, + RedisModuleString *prefix_str, RedisModuleString *entry_id, + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. + // Build the GcsEntry from the string data. + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); - fbb.Finish(message); + gcs_entry->add_entries(data_buf, data_len); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -696,27 +700,20 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the flatbuffer from the set of log entries. + // Build the GcsEntry from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - std::vector> data; + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - data.push_back(fbb.CreateString(element_str, len)); + gcs_entry->add_entries(element_str, len); } - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); - fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsEntry( - fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(std::vector>())); - fbb.Finish(message); + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -752,11 +749,12 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + // Serialize the data to a GcsEntry to return to the client. + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } return REDISMODULE_OK; } @@ -870,10 +868,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); return RedisModule_ReplyWithNull(ctx); } @@ -940,53 +939,6 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = - flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); - if (!is_nil_result) { - do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); - } - - if (do_update) { - REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), - "mutate_scheduling_state failed"); - } - REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -1016,7 +968,6 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); -AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1082,12 +1033,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 33f1615580a6a..b7c19ebfd595e 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,6 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" +#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -39,48 +40,44 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } } else { if (failure != nullptr) { - (failure)(client_, id, *dataT); + (failure)(client_, id, *data); } } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback), log_length); } template @@ -89,16 +86,15 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; + std::vector results; if (!reply.IsNil()) { - const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data data; + data.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(data)); } } lookup(client_, id, results); @@ -115,7 +111,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; @@ -141,19 +137,16 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + std::vector results; + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data result; + result.ParseFromString(gcs_entry.entries(i)); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->change_mode(), results); + subscribe(client_, id, gcs_entry.change_mode(), results); } } }; @@ -234,19 +227,17 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -255,7 +246,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -277,7 +268,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -299,36 +290,30 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -348,26 +333,16 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(data_map.size() * 2); - for (auto const &pair : data_map) { - // Add the key. - data_vec.push_back(fbb.CreateString(pair.first)); - flatbuffers::FlatBufferBuilder fbb_data; - fbb_data.ForceDefaults(true); - fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); - std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), - fbb_data.GetSize()); - // Add the value. - data_vec.push_back(fbb.CreateString(data)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); + for (const auto &pair : data_map) { + gcs_entry.add_entries(pair.first); + gcs_entry.add_entries(pair.second->SerializeAsString()); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -380,19 +355,15 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(keys.size()); - // Add the keys. - for (auto const &key : keys) { - data_vec.push_back(fbb.CreateString(key)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::REMOVE); + for (const auto &key : keys) { + gcs_entry.add_entries(key); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), - fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -412,17 +383,15 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - results.emplace(key, std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + results.emplace(key, std::move(value)); } } lookup(client_, id, results); @@ -451,31 +420,24 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); DataMap data_map; - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - if (root->change_mode() == GcsChangeMode::REMOVE) { - for (size_t i = 0; i < root->entries()->size(); i++) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - data_map.emplace(key, std::shared_ptr()); + if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { + for (const auto &key : gcs_entry.entries()) { + data_map.emplace(key, std::shared_ptr()); } } else { - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - data_map.emplace(key, std::move(result)); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + data_map.emplace(key, std::move(value)); } } - subscribe(client_, id, root->change_mode(), data_map); + subscribe(client_, id, gcs_entry.change_mode(), data_map); } } }; @@ -490,11 +452,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->type = type; - data->error_message = error_message; - data->timestamp = timestamp; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_type(type); + data->set_error_message(error_message); + data->set_timestamp(timestamp); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -503,11 +465,9 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - auto data = std::make_shared(); - // There is some room for optimization here because the Append function will just - // call "Pack" and undo the "UnPack". - profile_events.UnPackTo(data.get()); - + // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. + auto data = std::make_shared(); + data->CopyFrom(profile_events); return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -517,9 +477,9 @@ std::string ProfileTable::DebugString() const { } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->is_dead = is_dead; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -527,7 +487,8 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && + (entry.second.entry_type() == ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -537,7 +498,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -549,7 +510,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -559,15 +520,16 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && + entry.second.entry_type() == ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableDataT &data) { - ClientID client_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -578,16 +540,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); - bool is_deleted = (data.entry_type == EntryType::DELETION); - bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == ClientTableData::DELETION); + bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (entry->second.entry_type == EntryType::DELETION) { - RAY_CHECK((data.entry_type == EntryType::DELETION)) + if (entry->second.entry_type() == ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -595,64 +557,64 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type == EntryType::INSERTION) || - (data.entry_type == EntryType::DELETION)) { + if ((data.entry_type() == ClientTableData::INSERTION) || + (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)) { + } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { - auto const &resource_name = data.resources_total_label[i]; - auto const &capacity = data.resources_total_capacity[i]; + for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { + auto const &resource_name = data.resources_total_label(i); + auto const &capacity = data.resources_total_capacity(i); // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label.begin(), - cache_data.resources_total_label.end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label.end()) { - auto index = std::distance(cache_data.resources_total_label.begin(), + std::find(cache_data.resources_total_label().begin(), + cache_data.resources_total_label().end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label().end()) { + auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_capacity[index] = capacity; + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.set_resources_total_capacity(index, capacity); } // .. delete if deletion call. - else if (data.entry_type == EntryType::RES_DELETE) { - cache_data.resources_total_label.erase( - cache_data.resources_total_label.begin() + index); - cache_data.resources_total_capacity.erase( - cache_data.resources_total_capacity.begin() + index); + else if (data.entry_type() == ClientTableData::RES_DELETE) { + cache_data.mutable_resources_total_label()->erase( + cache_data.resources_total_label().begin() + index); + cache_data.mutable_resources_total_capacity()->erase( + cache_data.resources_total_capacity().begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_label.push_back(resource_name); - cache_data.resources_total_capacity.push_back(capacity); + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.add_resources_total_label(resource_name); + cache_data.add_resources_total_capacity(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type == EntryType::INSERTION) { + if (data.entry_type() == ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type == EntryType::DELETION) { + } else if (data.entry_type() == ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -660,11 +622,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_DELETE) { + } else if (data.entry_type() == ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -672,54 +634,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableDataT &local_client) { +Status ClientTable::Connect(const ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id == local_client_.client_id); + RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::INSERTION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableDataT &data) { + const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); + if (notification.entry_type() != ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); } else { - auto iter = connected_nodes.find(notification.client_id); + auto iter = connected_nodes.find(notification.client_id()); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id, notification); + disconnected_nodes.emplace(notification.client_id(), notification); } } for (const auto &pair : connected_nodes) { @@ -742,10 +704,10 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -759,24 +721,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->client_id = dead_client_id.Binary(); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(); + data->set_client_id(dead_client_id.Binary()); + data->set_entry_type(ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableDataT &client_info) const { + ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::Nil().Binary(); + client_info.set_client_id(ClientID::Nil().Binary()); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ -798,31 +760,29 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { - std::shared_ptr copy = - std::make_shared(data); - copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.Binary(); + const ActorCheckpointIdData &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->add_timestamps(current_sys_time_ms()); + copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps.size() > num_to_keep) { + while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &checkpoint_id = - ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " - << actor_id; - copy->timestamps.erase(copy->timestamps.begin()); - copy->checkpoint_ids.erase(0, kUniqueIDSize); - client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); + const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; + copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); + copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); + client_->actor_checkpoint_table().Delete(driver_id, to_delete); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); - data->actor_id = id.Binary(); - data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.Binary(); + std::shared_ptr data = + std::make_shared(); + data->set_actor_id(id.Binary()); + data->add_timestamps(current_sys_time_ms()); + *data->add_checkpoint_ids() = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); @@ -830,8 +790,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -template class Log; -template class Table; +template class Log; template class Table; template class Log; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 6a1d502a7f549..2ecc3440839e2 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,10 +11,8 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -// TODO(rkn): Remove this include. -#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -22,6 +20,25 @@ namespace ray { namespace gcs { +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorTableData; +using rpc::GcsChangeMode; +using rpc::GcsEntry; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; +using rpc::ObjectTableData; +using rpc::ProfileTableData; +using rpc::RayResource; +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::TaskLeaseData; +using rpc::TaskReconstructionData; +using rpc::TaskTableData; + class RedisContext; class AsyncGcsClient; @@ -48,13 +65,12 @@ class PubsubInterface { template class LogInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -72,12 +88,11 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + const std::vector &data)>; + using NotificationCallback = + std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -86,7 +101,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -111,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -126,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. /// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -259,10 +274,9 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -280,9 +294,8 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -305,7 +318,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -369,12 +382,11 @@ class Table : private Log, template class SetInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -392,7 +404,6 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -414,7 +425,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -425,7 +436,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -454,8 +465,7 @@ class Set : private Log, template class HashInterface { public: - using DataT = typename Data::NativeTableType; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -544,8 +554,7 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -595,7 +604,7 @@ class DynamicResourceTable : public Hash { DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE; + pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; prefix_ = TablePrefix::NODE_RESOURCE; }; @@ -607,7 +616,7 @@ class ObjectTable : public Set { ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT; + pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; prefix_ = TablePrefix::OBJECT; }; @@ -619,7 +628,7 @@ class HeartbeatTable : public Table { HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT; + pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} @@ -630,7 +639,7 @@ class HeartbeatBatchTable : public Table { HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} @@ -641,7 +650,7 @@ class DriverTable : public Log { DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER; + pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; prefix_ = TablePrefix::DRIVER; }; @@ -655,18 +664,6 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { - public: - FunctionTable(const std::vector> &contexts, - AsyncGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::NO_PUBLISH; - prefix_ = TablePrefix::FUNCTION; - }; -}; - -using ClassTable = Table; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -677,7 +674,7 @@ class ActorTable : public Log { ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR; + pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; prefix_ = TablePrefix::ACTOR; } }; @@ -696,12 +693,12 @@ class TaskLeaseTable : public Table { TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE; + pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. It's okay if this command fails // since the lease entry itself contains the expiration period. In the @@ -709,9 +706,8 @@ class TaskLeaseTable : public Table { // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.Binary(), - std::to_string(data->timeout)}; + std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), + std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } @@ -747,12 +743,12 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK; + pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; prefix_ = TablePrefix::RAYLET_TASK; } @@ -770,7 +766,7 @@ class ErrorTable : private Log { ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO; + pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; prefix_ = TablePrefix::ERROR_INFO; }; @@ -815,10 +811,6 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; - -using ConfigTable = Table; - /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -831,7 +823,7 @@ using ConfigTable = Table; class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -842,11 +834,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT; + pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.Binary(); + local_client_.set_client_id(client_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -855,7 +847,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableDataT &local_client); + ray::Status Connect(const ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -898,7 +890,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableDataT &client_info) const; + void GetClient(const ClientID &client, ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -908,7 +900,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableDataT &GetLocalClient() const; + const ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. /// @@ -919,7 +911,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -940,15 +932,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableDataT local_client_; + ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -958,7 +950,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 5b6794a505d3f..454379d183024 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,18 +8,22 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::ClientTableData; +using ray::rpc::GcsChangeMode; +using ray::rpc::ObjectTableData; + /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -42,7 +46,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -79,9 +83,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; @@ -92,9 +96,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; @@ -102,14 +106,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - ClientTableDataT client_data; + ClientTableData client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::FromBinary(client_data.client_id); + ClientID result_client_id = ClientID::FromBinary(client_data.client_id()); if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.entry_type == EntryType::INSERTION) { - connection_info.ip = client_data.node_manager_address; - connection_info.port = static_cast(client_data.object_manager_port); + if (client_data.entry_type() == ClientTableData::INSERTION) { + connection_info.ip = client_data.node_manager_address(); + connection_info.port = static_cast(client_data.object_manager_port()); } } } @@ -208,7 +212,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, status = gcs_client_->object_table().Lookup( DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_updates) { + const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 954162c21aef2..964cee605cede 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_send"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_send"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_receive"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_receive"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -801,11 +801,12 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEventT profile_event; - profile_event.event_type = "receive_pull_request"; - profile_event.start_time = current_sys_time_seconds(); - profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("receive_pull_request"); + profile_event.set_start_time(current_sys_time_seconds()); + profile_event.set_end_time(profile_event.start_time()); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -938,13 +939,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { - ProfileTableDataT profile_info; - profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.Binary(); +rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + rpc::ProfileTableData profile_info; + profile_info.set_component_type("object_manager"); + profile_info.set_component_id(client_id_.Binary()); for (auto const &profile_event : profile_events_) { - profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + profile_info.add_profile_events()->CopyFrom(profile_event); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6318250ae3e80..6664dd0a93bde 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -180,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableDataT GetAndResetProfilingInfo(); + rpc::ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. /// @@ -412,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 55aa59124a999..2d5292842acf8 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,6 +11,8 @@ namespace ray { +using rpc::ClientTableData; + std::string store_executable; static inline void flushall_redis(void) { @@ -52,10 +54,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -242,8 +244,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -438,16 +440,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" - << "ClientIp=" << data.node_manager_address << "\n" - << "ClientPort=" << data.node_manager_port; - ClientTableDataT data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" + << "ClientIp=" << data.node_manager_address() << "\n" + << "ClientPort=" << data.node_manager_port(); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" - << "ClientIp=" << data2.node_manager_address << "\n" - << "ClientPort=" << data2.node_manager_port; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" + << "ClientIp=" << data2.node_manager_address() << "\n" + << "ClientPort=" << data2.node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index ee6c78d8ed42b..45b80a267f2f2 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,6 +14,8 @@ int64_t wait_timeout_ms; namespace ray { +using rpc::ClientTableData; + static inline void flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -46,10 +48,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -221,8 +223,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -457,19 +459,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); - ClientTableDataT data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto new file mode 100644 index 0000000000000..d0b2c5e007fe3 --- /dev/null +++ b/src/ray/protobuf/gcs.proto @@ -0,0 +1,280 @@ +syntax = "proto3"; + +package ray.rpc; + +option java_package = "org.ray.runtime.generated"; + +// Language of a worker or task. +enum Language { + PYTHON = 0; + CPP = 1; + JAVA = 2; +} + +// These indexes are mapped to strings in ray_redis_module.cc. +enum TablePrefix { + TABLE_PREFIX_MIN = 0; + UNUSED = 1; + TASK = 2; + RAYLET_TASK = 3; + CLIENT = 4; + OBJECT = 5; + ACTOR = 6; + FUNCTION = 7; + TASK_RECONSTRUCTION = 8; + HEARTBEAT = 9; + HEARTBEAT_BATCH = 10; + ERROR_INFO = 11; + DRIVER = 12; + PROFILE = 13; + TASK_LEASE = 14; + ACTOR_CHECKPOINT = 15; + ACTOR_CHECKPOINT_ID = 16; + NODE_RESOURCE = 17; + TABLE_PREFIX_MAX = 18; +} + +// The channel that Add operations to the Table should be published on, if any. +enum TablePubsub { + TABLE_PUBSUB_MIN = 0; + NO_PUBLISH = 1; + TASK_PUBSUB = 2; + RAYLET_TASK_PUBSUB = 3; + CLIENT_PUBSUB = 4; + OBJECT_PUBSUB = 5; + ACTOR_PUBSUB = 6; + HEARTBEAT_PUBSUB = 7; + HEARTBEAT_BATCH_PUBSUB = 8; + ERROR_INFO_PUBSUB = 9; + TASK_LEASE_PUBSUB = 10; + DRIVER_PUBSUB = 11; + NODE_RESOURCE_PUBSUB = 12; + TABLE_PUBSUB_MAX = 13; +} + +enum GcsChangeMode { + APPEND_OR_ADD = 0; + REMOVE = 1; +} + +message GcsEntry { + GcsChangeMode change_mode = 1; + bytes id = 2; + repeated bytes entries = 3; +} + +message ObjectTableData { + // The size of the object. + uint64 object_size = 1; + // The node manager ID that this object appeared on or was evicted by. + bytes manager = 2; +} + +message TaskReconstructionData { + // The number of times this task has been reconstructed so far. + uint64 num_reconstructions = 1; + // The node manager that is trying to reconstruct the task. + bytes node_manager_id = 2; +} + +// TODO(hchen): Task table currently still uses flatbuffers-defined data structure +// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should +// be migrated to protobuf very soon. +message TaskTableData { + // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. + bytes task = 1; +} + +message ActorTableData { + // State of an actor. + enum ActorState { + // Actor is alive. + ALIVE = 0; + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1; + // Actor is already dead and won't be reconstructed. + DEAD = 2; + } + // The ID of the actor that was created. + bytes actor_id = 1; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + bytes actor_creation_dummy_object_id = 2; + // The ID of the driver that created the actor. + bytes driver_id = 3; + // The ID of the node manager that created the actor. + bytes node_manager_id = 4; + // Current state of this actor. + ActorState state = 5; + // Max number of times this actor should be reconstructed. + uint64 max_reconstructions = 6; + // Remaining number of reconstructions. + uint64 remaining_reconstructions = 7; +} + +message ErrorTableData { + // The ID of the driver that the error is for. + bytes driver_id = 1; + // The type of the error. + string type = 2; + // The error message. + string error_message = 3; + // The timestamp of the error message. + double timestamp = 4; +} + +message ProfileTableData { + // Represents a profile event. + message ProfileEvent { + // The type of the event. + string event_type = 1; + // The start time of the event. + double start_time = 2; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + double end_time = 3; + // Additional data associated with the event. This data must be serialized + // using JSON. + string extra_data = 4; + } + + // The type of the component that generated the event, e.g., worker or + // object_manager, or node_manager. + string component_type = 1; + // An identifier for the component that generated the event. + bytes component_id = 2; + // An identifier for the node that generated the event. + string node_ip_address = 3; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + repeated ProfileEvent profile_events = 4; +} + +message RayResource { + // The type of the resource. + string resource_name = 1; + // The total capacity of this resource type. + double resource_capacity = 2; +} + +message ClientTableData { + // Enum for the entry type in the ClientTable + enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; + } + + // The client ID of the client that the message is about. + bytes client_id = 1; + // The IP address of the client's node manager. + string node_manager_address = 2; + // The IPC socket name of the client's raylet. + string raylet_socket_name = 3; + // The IPC socket name of the client's plasma store. + string object_store_socket_name = 4; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + int32 node_manager_port = 5; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + int32 object_manager_port = 6; + // Enum to store the entry type in the log + EntryType entry_type = 7; + + // TODO(hchen): Define the following resources in map format. + repeated string resources_total_label = 8; + repeated double resources_total_capacity = 9; +} + +message HeartbeatTableData { + // Node manager client id + bytes client_id = 1; + // TODO(hchen): Define the following resources in map format. + // Resource capacity currently available on this node manager. + repeated string resources_available_label = 2; + repeated double resources_available_capacity = 3; + // Total resource capacity configured for this node manager. + repeated string resources_total_label = 4; + repeated double resources_total_capacity = 5; + // Aggregate outstanding resource load on this node manager. + repeated string resource_load_label = 6; + repeated double resource_load_capacity = 7; +} + +message HeartbeatBatchTableData { + repeated HeartbeatTableData batch = 1; +} + +// Data for a lease on task execution. +message TaskLeaseData { + // Node manager client ID. + bytes node_manager_id = 1; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + uint64 acquired_at = 2; + // The period that the lease is active for. + uint64 timeout = 3; +} + +message DriverTableData { + // The driver ID. + bytes driver_id = 1; + // Whether it's dead. + bool is_dead = 2; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +message ActorCheckpointData { + // ID of this actor. + bytes actor_id = 1; + // The dummy object ID of actor's most recently executed task. + bytes execution_dependency = 2; + // A list of IDs of this actor's handles. + repeated bytes handle_ids = 3; + // The task counters of the above handles. + repeated uint64 task_counters = 4; + // The frontier dependencies of the above handles. + repeated bytes frontier_dependencies = 5; + // A list of unreleased dummy objects from this actor. + repeated bytes unreleased_dummy_objects = 6; + // The numbers of dependencies for the above unreleased dummy objects. + repeated uint32 num_dummy_object_dependencies = 7; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +message ActorCheckpointIdData { + // ID of this actor. + bytes actor_id = 1; + // IDs of this actor's available checkpoints. + repeated bytes checkpoint_ids = 2; + // A list of the timestamps for each of the above `checkpoint_ids`. + repeated uint64 timestamps = 3; +} + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 2; +} diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index cc587bc4d74e6..7f940006b5bee 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,34 +8,35 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { + execution_dependency_( + ObjectID::FromBinary(checkpoint_data.execution_dependency())) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); + for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.task_counter = checkpoint_data.task_counters(i); frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); } // Restore `dummy_objects_`. - for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id()); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,15 +44,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions; + return actor_table_data_.max_reconstructions(); } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions; + return actor_table_data_.remaining_reconstructions(); } const std::unordered_map @@ -96,7 +97,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -109,18 +110,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. - auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.Binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); + auto checkpoint_data = std::make_shared(); + checkpoint_data->set_actor_id(actor_id.Binary()); + checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.Binary()); - checkpoint_data->task_counters.push_back(frontier.second.task_counter); - checkpoint_data->frontier_dependencies.push_back( + checkpoint_data->add_handle_ids(frontier.first.Binary()); + checkpoint_data->add_task_counters(frontier.second.task_counter); + checkpoint_data->add_frontier_dependencies( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); - checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); + checkpoint_data->add_num_dummy_object_dependencies(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 8d7ce2a449ecc..208e4998263fc 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,13 +4,17 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::ActorTableData; +using ActorState = rpc::ActorTableData::ActorState; +using rpc::ActorCheckpointData; + /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -23,13 +27,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - ActorRegistration(const ActorTableDataT &actor_table_data); + explicit ActorRegistration(const ActorTableData &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data); + ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -46,15 +50,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableDataT &GetTableData() const { return actor_table_data_; } + const ActorTableData &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState &GetState() const { return actor_table_data_.state; } + const ActorState GetState() const { return actor_table_data_.state(); } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.state = state; } + void SetState(const ActorState &state) { actor_table_data_.set_state(state); } /// Get the actor's node manager location. /// @@ -131,13 +135,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. /// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableDataT actor_table_data_; + ActorTableData actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 32dddada52444..68d5aa817c2bb 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,15 +63,6 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} -Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { - // Deserialize and set entries for the uncommitted tasks. - auto tasks = task_request.uncommitted_tasks(); - for (auto it = tasks->begin(); it != tasks->end(); it++) { - const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); - } -} - boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -151,20 +142,6 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } -flatbuffers::Offset Lineage::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { - RAY_CHECK(GetEntry(task_id)); - // Serialize the task and object entries. - std::vector> uncommitted_tasks; - for (const auto &entry : entries_) { - uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); - } - - auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), - fbb.CreateVector(uncommitted_tasks)); - return request; -} - const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -176,7 +153,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -292,15 +269,11 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... - flatbuffers::FlatBufferBuilder fbb; - auto message = task->TaskData().ToFlatbuffer(fbb); - fbb.Finish(message); - auto task_data = std::make_shared(); - auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); - root->UnPackTo(task_data.get()); + auto task_data = std::make_shared(); + task_data->set_task(task->TaskData().Serialize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -365,8 +338,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } - - return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 5436fa372fa4c..37ce5caf65075 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,18 +4,17 @@ #include #include -// clang-format off -#include "ray/common/common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -// clang-format on +#include "ray/gcs/tables.h" +#include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::TaskTableData; + /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -136,12 +135,6 @@ class Lineage { /// Construct an empty Lineage. Lineage(); - /// Construct a Lineage from a ForwardTaskRequest. - /// - /// \param task_request The request to construct the lineage from. All - /// uncommitted tasks in the request will be added to the lineage. - Lineage(const protocol::ForwardTaskRequest &task_request); - /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -172,15 +165,6 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; - /// Serialize this lineage to a ForwardTaskRequest flatbuffer. - /// - /// \param entry_id The task ID to include in the ForwardTaskRequest - /// flatbuffer. - /// \return An offset to the serialized lineage. The serialization includes - /// all task and object entries in the lineage. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; - /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// @@ -221,7 +205,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -319,7 +303,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 43e64e4002925..a6184902f803b 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &data) { + const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 62ecb00b819f7..0a853260887e7 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id && - data.entry_type == EntryType::DELETION) { + if (client_id.Binary() == data.client_id() && + data.entry_type() == ClientTableData::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,10 +84,9 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->batch.push_back(std::unique_ptr( - new HeartbeatTableDataT(heartbeat.second))); + batch->add_batch()->CopyFrom(heartbeat.second); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index c69cc9f003e0a..5725e52cf495b 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,6 +11,10 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -35,7 +39,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data); + const HeartbeatTableData &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -50,7 +54,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fc364539cccea..808eeb6fd2110 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ActorState::ALIVE) { + if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -130,7 +130,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -139,8 +139,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); + const TaskLeaseData &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -150,7 +150,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. - reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -164,7 +164,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -177,34 +177,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceCreateUpdated(data); }; + const ClientTableData &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceDeleted(data); }; + const ClientTableData &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { + const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -215,7 +215,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -251,12 +251,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) - << " " << entry.is_dead; - if (entry.is_dead) { - auto driver_id = DriverID::FromBinary(entry.driver_id); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " + << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); + if (entry.is_dead()) { + auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -288,26 +288,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.Binary(); + heartbeat_data->set_client_id(my_client_id.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->resources_available_label.push_back(resource_pair.first); - heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_available_label(resource_pair.first); + heartbeat_data->add_resources_available_capacity(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->resources_total_label.push_back(resource_pair.first); - heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_total_label(resource_pair.first); + heartbeat_data->add_resources_total_capacity(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->resource_load_label.push_back(resource_pair.first); - heartbeat_data->resource_load_capacity.push_back(resource_pair.second); + heartbeat_data->add_resource_load_label(resource_pair.first); + heartbeat_data->add_resource_load_capacity(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -335,13 +335,8 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events.size() > 0) { - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_info); - fbb.Finish(message); - auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); - - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); + if (profile_info.profile_events_size() > 0) { + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } // Reset the timer. @@ -358,8 +353,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ClientAdded(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -378,19 +373,20 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address, - client_data.node_manager_port, client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address(), + client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet resources_total( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { +void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. - const ClientID client_id = ClientID::FromBinary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -418,7 +414,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.GetState() == ActorState::ALIVE) { + actor_entry.second.GetState() == ActorTableData::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first << " is disconnected, because its node " << client_id << " is removed from cluster. It may be reconstructed."; @@ -436,14 +432,15 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { lineage_cache_.FlushAllUncommittedTasks(); } -void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". Updating resource map."; - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -472,12 +469,13 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { return; } -void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceDeleted(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -523,7 +521,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() { } void NodeManager::HeartbeatAdded(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { // Locate the client id in remote client table and update available resources based on // the received heartbeat information. auto it = cluster_resource_map_.find(client_id); @@ -535,10 +533,12 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } SchedulingResources &remote_resources = it->second; - ResourceSet remote_available(heartbeat_data.resources_available_label, - heartbeat_data.resources_available_capacity); - ResourceSet remote_load(heartbeat_data.resource_load_label, - heartbeat_data.resource_load_capacity); + ResourceSet remote_available( + rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()), + rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity())); + ResourceSet remote_load( + rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()), + rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity())); // TODO(atumanov): assert that the load is a non-empty ResourceSet. remote_resources.SetAvailableResources(std::move(remote_available)); // Extract the load information and save it locally. @@ -563,40 +563,41 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } } -void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) { +void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. - for (const auto &heartbeat_data : heartbeat_batch.batch) { - const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); + for (const auto &heartbeat_data : heartbeat_batch.batch()) { + const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); if (client_id == local_client_id) { // Skip heartbeats from self. continue; } - HeartbeatAdded(client_id, *heartbeat_data); + HeartbeatAdded(client_id, heartbeat_data); } } void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback) { // Copy the actor notification data. - auto actor_notification = std::make_shared(data); + auto actor_notification = std::make_shared(data); // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions - - actor_notification->remaining_reconstructions); - if (actor_notification->state != ActorState::ALIVE) { + int log_length = 2 * (actor_notification->max_reconstructions() - + actor_notification->remaining_reconstructions()); + if (actor_notification->state() != ActorTableData::ALIVE) { // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } // If we successful appended a record to the GCS table of the actor that // has died, signal this to anyone receiving signals from this actor. auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -633,11 +634,12 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", state = " + << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorState::ALIVE) { + if (actor_registration.GetState() == ActorTableData::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -664,7 +666,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorState::DEAD) { + } else if (actor_registration.GetState() == ActorTableData::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -673,7 +675,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. This will mark the tasks as waiting for actor @@ -794,8 +796,20 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + ProfileTableDataT fbs_message; + flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + rpc::ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message.component_type); + profile_table_data.set_component_id(fbs_message.component_id); + for (const auto &fbs_event : fbs_message.profile_events) { + rpc::ProfileTableData::ProfileEvent *event = + profile_table_data.add_profile_events(); + event->set_event_type(fbs_event->event_type); + event->set_start_time(fbs_event->start_time); + event->set_end_time(fbs_event->end_time); + event->set_extra_data(fbs_event->extra_data); + } + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -863,8 +877,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorState::RECONSTRUCTING - : ActorState::DEAD; + ? ActorTableData::RECONSTRUCTING + : ActorTableData::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -873,8 +887,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.state = new_state; + ActorTableData new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.set_state(new_state); if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -885,7 +899,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1160,7 +1174,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointDataT &data) { + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1244,19 +1258,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableDataT object - ClientTableDataT data; + // Add the new resource to a skeleton ClientTableData object + ClientTableData data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.resources_total_label = std::vector{resource_name}; - data.resources_total_capacity = std::vector{capacity}; + data.add_resources_total_label(resource_name); + data.add_resources_total_capacity(capacity); // Set the correct flag for entry_type if (is_deletion) { - data.entry_type = EntryType::RES_DELETE; + data.set_entry_type(ClientTableData::RES_DELETE); } else { - data.entry_type = EntryType::RES_CREATEUPDATE; + data.set_entry_type(ClientTableData::RES_CREATEUPDATE); } // Submit to the client table. This calls the ResourceCreateUpdated callback, which @@ -1265,7 +1279,7 @@ void NodeManager::ProcessSetResourceRequest( if (not worker) { worker = worker_pool_.GetRegisteredDriver(client); } - auto data_shared_ptr = std::make_shared(data); + auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); @@ -1370,7 +1384,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << EnumNameErrorType(error_type) << "."; + << ErrorType_Name(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1488,9 +1502,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // If we have already seen this actor and this actor is not being reconstructed, // its location is known. bool location_known = - seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; + seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING; if (location_known) { - if (actor_entry->second.GetState() == ActorState::DEAD) { + if (actor_entry->second.GetState() == ActorTableData::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); @@ -1535,7 +1549,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // we missed the creation notification. auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. @@ -1861,11 +1875,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { +ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); auto actor_id = task.GetTaskSpecification().ActorCreationId(); auto actor_entry = actor_registry_.find(actor_id); - ActorTableDataT new_actor_data; + ActorTableData new_actor_data; // TODO(swang): If this is an actor that was reconstructed, and previous // actor notifications were delayed, then this node may not have an entry for // the actor in actor_regisry_. Then, the fields for the number of @@ -1873,32 +1887,33 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. - new_actor_data.actor_id = actor_id.Binary(); - new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().Binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); - new_actor_data.max_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_actor_id(actor_id.Binary()); + new_actor_data.set_actor_creation_dummy_object_id( + task.GetTaskSpecification().ActorDummyObject().Binary()); + new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); + new_actor_data.set_max_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.remaining_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_remaining_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING); // Copy the static fields from the current actor entry. new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.remaining_reconstructions--; + new_actor_data.set_remaining_reconstructions( + new_actor_data.remaining_reconstructions() - 1); } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. - new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().Binary(); - new_actor_data.state = ActorState::ALIVE; + new_actor_data.set_node_manager_id( + gcs_client_->client_table().GetLocalClientId().Binary()); + new_actor_data.set_state(ActorTableData::ALIVE); return new_actor_data; } @@ -1934,7 +1949,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, - const ActorCheckpointDataT &checkpoint_data) { + const ActorCheckpointData &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -1948,7 +1963,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -1964,8 +1979,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { PublishActorStateTransition( actor_id, new_actor_data, /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -2004,10 +2018,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - const Task task(task_data); + auto message = flatbuffers::GetRoot(task_data.task().data()); + const Task task(*message); ResubmitTask(task); }, /*failure_callback=*/ @@ -2035,7 +2050,7 @@ void NodeManager::ResubmitTask(const Task &task) { if (task.GetTaskSpecification().IsActorCreationTask()) { const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); const auto it = actor_registry_.find(actor_id); - if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { + if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) { // If the actor is still alive, then do not resubmit the task. If the // actor actually is dead and a result is needed, then reconstruction // for this task will be triggered again. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 61613358330c8..f45c8b0355536 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,7 +10,6 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" -#include "ray/gcs/format/util.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -26,6 +25,13 @@ namespace ray { namespace raylet { +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorType; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + struct NodeManagerConfig { /// The node's resource configuration. ResourceSet resource_config; @@ -112,22 +118,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param data Data associated with the new client. /// \return Void. - void ClientAdded(const ClientTableDataT &data); + void ClientAdded(const ClientTableData &data); /// Handler for the removal of a GCS client. /// \param client_data Data associated with the removed client. /// \return Void. - void ClientRemoved(const ClientTableDataT &client_data); + void ClientRemoved(const ClientTableData &client_data); /// Handler for the addition or updation of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceCreateUpdated(const ClientTableDataT &client_data); + void ResourceCreateUpdated(const ClientTableData &client_data); /// Handler for the deletion of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceDeleted(const ClientTableDataT &client_data); + void ResourceDeleted(const ClientTableData &client_data); /// Evaluates the local infeasible queue to check if any tasks can be scheduled. /// This is called whenever there's an update to the resources on the local client. @@ -150,11 +156,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param id The ID of the node manager that sent the heartbeat. /// \param data The heartbeat data including load information. /// \return Void. - void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data); + void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data); /// Handler for a heartbeat batch notification from the GCS /// /// \param heartbeat_batch The batch of heartbeat data. - void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch); + void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch); /// Methods for task scheduling. @@ -206,7 +212,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Helper function to produce actor table data for a newly created actor. /// /// \param task The actor creation task that created the actor. - ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); + ActorTableData CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -317,7 +323,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param failure_callback An optional callback to call if the publish is /// unsuccessful. void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -346,7 +352,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 473e6c263ffea..cbf9b25213caf 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = node_ip_address; - client_info.raylet_socket_name = raylet_socket_name; - client_info.object_store_socket_name = object_store_socket_name; - client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_.GetServerPort(); + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(node_ip_address); + client_info.set_raylet_socket_name(raylet_socket_name); + client_info.set_object_store_socket_name(object_store_socket_name); + client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); + client_info.set_node_manager_port(node_manager_.GetServerPort()); // Add resource information. for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.resources_total_label.push_back(resource_pair.first); - client_info.resources_total_capacity.push_back(resource_pair.second); + client_info.add_resources_total_label(resource_pair.first); + client_info.add_resources_total_capacity(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address << ":" - << client_info.node_manager_port << " object manager at " - << client_info.node_manager_address << ":" - << client_info.object_manager_port; + << " started on " << client_info.node_manager_address() << ":" + << client_info.node_manager_port() << " object manager at " + << client_info.node_manager_address() << ":" + << client_info.object_manager_port(); ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 26fe74b2b6225..9367a5054591e 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; + class Task; class NodeManager; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 97c86ea73cd87..bf5c1acfaa377 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. - auto reconstruction_entry = std::make_shared(); - reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.Binary(); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->set_num_reconstructions(reconstruction_attempt); + reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cd969cc2706e0..a194443e14258 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,6 +17,8 @@ namespace ray { namespace raylet { +using rpc::TaskReconstructionData; + class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4ccebd0c0c09e..12d9336a382fb 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,6 +14,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -83,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -110,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -132,15 +134,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -159,9 +161,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { + const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout); + task_lease.timeout()); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -314,10 +316,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = 2 * test_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(2 * test_period); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -328,7 +330,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout * 1.1); + Run(task_lease_data->timeout() * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -341,10 +343,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = reconstruction_timeout_ms_; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(reconstruction_timeout_ms_); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -393,14 +395,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); - task_reconstruction_data->num_reconstructions = 0; + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index c5155b96b0c15..89028c733d0d6 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.Hex(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = it->second.lease_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(client_id_.Hex()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(it->second.lease_period); RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 3788a5eae7aed..a965582952348 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,6 +13,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index e0f832a128704..f7a60989fcba5 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 719378216fb7a..16086565de805 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -48,8 +48,8 @@ WorkerPool::WorkerPool( : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - gcs_client_(std::move(gcs_client)), - last_warning_multiple_(0) { + last_warning_multiple_(0), + gcs_client_(std::move(gcs_client)) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. If we don't do this, then worker processes will diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 6ecc6c3c4a341..59ae75ae33bee 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,6 +1,7 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H +#include #include #include "ray/common/status.h" @@ -27,6 +28,18 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedPtrField &pb_repeated) { + return std::vector(pb_repeated.begin(), pb_repeated.end()); +} + +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedField &pb_repeated) { + return std::vector(pb_repeated.begin(), pb_repeated.end()); +} + } // namespace rpc } // namespace ray