From 3c78c68fb4a594f53a5bef39f76b15191591b081 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 20 Aug 2019 20:45:01 +0800 Subject: [PATCH 01/24] Support direct actor call in Java worker --- .../ray/api/options/ActorCreationOptions.java | 20 ++++++++++++++----- java/test.sh | 3 +++ src/ray/core_worker/lib/java/jni_init.cc | 6 ++++++ src/ray/core_worker/lib/java/jni_utils.h | 4 ++++ ...rg_ray_runtime_task_NativeTaskSubmitter.cc | 9 ++++++++- 5 files changed, 36 insertions(+), 6 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 2e14ca8584ddc..22a9d7b301836 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,26 +10,31 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); + private static final boolean DEFAULT_IS_DIRECT_CALL = "1" + .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL")); public final int maxReconstructions; + public final boolean isDirectCall; + public final String jvmOptions; - private ActorCreationOptions(Map resources, - int maxReconstructions, - String jvmOptions) { + private ActorCreationOptions(Map resources, int maxReconstructions, + boolean isDirectCall, String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.isDirectCall = isDirectCall; this.jvmOptions = jvmOptions; } /** - * The inner class for building ActorCreationOptions. + * The inner class for building ActorCreationOptions. */ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private boolean isDirectCall = DEFAULT_IS_DIRECT_CALL; private String jvmOptions = ""; public Builder setResources(Map resources) { @@ -42,13 +47,18 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } + public Builder setIsDirectCall(boolean isDirectCall) { + this.isDirectCall = isDirectCall; + return this; + } + public Builder setJvmOptions(String jvmOptions) { this.jvmOptions = jvmOptions; return this; } public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); + return new ActorCreationOptions(resources, maxReconstructions, isDirectCall, jvmOptions); } } diff --git a/java/test.sh b/java/test.sh index ba728f14bf38a..bc74a4072ab75 100755 --- a/java/test.sh +++ b/java/test.sh @@ -27,6 +27,9 @@ echo "Running tests under cluster mode." # bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$? ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +echo "Running tests under cluster mode with direct actor call turned on." +ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml + echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index a7bd918acfb54..6d63f92c87552 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -49,7 +49,9 @@ jclass java_base_task_options_class; jfieldID java_base_task_options_resources; jclass java_actor_creation_options_class; +jfieldID java_actor_creation_options_default_is_direct_call; jfieldID java_actor_creation_options_max_reconstructions; +jfieldID java_actor_creation_options_is_direct_call; jfieldID java_actor_creation_options_jvm_options; jclass java_gcs_client_options_class; @@ -145,8 +147,12 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_class = LoadClass(env, "org/ray/api/options/ActorCreationOptions"); + java_actor_creation_options_default_is_direct_call = env->GetStaticFieldID( + java_actor_creation_options_class, "DEFAULT_IS_DIRECT_CALL", "Z"); java_actor_creation_options_max_reconstructions = env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); + java_actor_creation_options_is_direct_call = + env->GetFieldID(java_actor_creation_options_class, "isDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 396b5a8414174..396c4301d6637 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -91,8 +91,12 @@ extern jfieldID java_base_task_options_resources; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; +/// DEFAULT_IS_DIRECT_CALL field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_default_is_direct_call; /// maxReconstructions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_reconstructions; +/// isDirectCall field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_is_direct_call; /// jvmOptions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_jvm_options; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index ca219ddae15df..dd5dc0610af3f 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -76,11 +76,14 @@ inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject call inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, jobject actorCreationOptions) { uint64_t max_reconstructions = 0; + bool is_direct_call; std::unordered_map resources; std::vector dynamic_worker_options; if (actorCreationOptions) { max_reconstructions = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_reconstructions)); + is_direct_call = env->GetBooleanField(actorCreationOptions, + java_actor_creation_options_is_direct_call); jobject java_resources = env->GetObjectField(actorCreationOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); @@ -88,10 +91,14 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, env, (jstring)env->GetObjectField(actorCreationOptions, java_actor_creation_options_jvm_options)); dynamic_worker_options.emplace_back(jvm_options); + } else { + is_direct_call = + env->GetStaticBooleanField(java_actor_creation_options_class, + java_actor_creation_options_default_is_direct_call); } ray::ActorCreationOptions action_creation_options{ - static_cast(max_reconstructions), false, resources, + static_cast(max_reconstructions), is_direct_call, resources, dynamic_worker_options}; return action_creation_options; } From 9530f162c6001a646afc882272c5a72a38752bf0 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 20 Aug 2019 22:22:31 +0800 Subject: [PATCH 02/24] Skip some tests --- .../java/org/ray/api/options/ActorCreationOptions.java | 2 +- java/test/src/main/java/org/ray/api/TestUtils.java | 7 +++++++ .../java/org/ray/api/test/ActorReconstructionTest.java | 4 ++++ java/test/src/main/java/org/ray/api/test/ActorTest.java | 2 ++ .../java/org/ray/api/test/CrossLanguageInvocationTest.java | 4 ++++ java/test/src/main/java/org/ray/api/test/FailureTest.java | 3 +++ java/test/src/main/java/org/ray/api/test/StressTest.java | 2 ++ 7 files changed, 23 insertions(+), 1 deletion(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 22a9d7b301836..69918edc3e34c 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,7 +10,7 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); - private static final boolean DEFAULT_IS_DIRECT_CALL = "1" + public static final boolean DEFAULT_IS_DIRECT_CALL = "1" .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL")); public final int maxReconstructions; diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 3badb110445df..4e29f36dd3df6 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -2,6 +2,7 @@ import java.util.function.Supplier; import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; import org.testng.Assert; @@ -18,6 +19,12 @@ public static void skipTestUnderSingleProcess() { } } + public static void skipTestIfDirectActorCallEnabled() { + if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { + throw new SkipException("This test doesn't work when direct actor call is enabled."); + } + } + /** * Wait until the given condition is met. * diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 3e50b4d966a2f..1d1d730d96c98 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -47,6 +47,8 @@ public int getPid() { @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); + // No lineage cache when direct actor call is enabled. + TestUtils.skipTestIfDirectActorCallEnabled(); ActorCreationOptions options = new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(Counter::new, options); @@ -128,6 +130,8 @@ public void checkpointExpired(ActorId actorId, UniqueId checkpointId) { @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); + // Actor checkpointing is not implemented in direct actor call yet. + TestUtils.skipTestIfDirectActorCallEnabled(); ActorCreationOptions options = new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(CheckpointableCounter::new, options); diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 784c82c4c92c7..7ee4bb12dade0 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -104,6 +104,8 @@ public void testForkingActorHandle() { @Test public void testUnreconstructableActorObject() throws InterruptedException { TestUtils.skipTestUnderSingleProcess(); + // The UnreconstructableException is created by raylet. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor counter = Ray.createActor(Counter::new, 100); // Call an actor method. RayObject value = Ray.call(Counter::getValue, counter); diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 2f75c7b54ff4d..f69c520fec157 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -9,6 +9,7 @@ import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.RayPyActor; +import org.ray.api.TestUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -47,6 +48,9 @@ public void testCallingPythonFunction() { @Test public void testCallingPythonActor() { + // Direct actor call only allows passing arguments as values. + // However, bytes arguments are passed from Java to Python as references. + TestUtils.skipTestIfDirectActorCallEnabled(); RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); Assert.assertEquals(res.get(), "2".getBytes()); diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index b47b010ae25a8..2f23a58191932 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -105,6 +105,9 @@ public void testWorkerProcessDying() { @Test public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); + // If direct actor call is enabled, we can get a RayActorException only if the actor + // is already dead before submitting the task. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(BadActor::new, false); try { Ray.call(BadActor::badMethod2, actor).get(); diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index e2efecbf222e1..a1dc5116e055d 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -75,6 +75,8 @@ public int ping(int n) { @Test public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); + // TODO (kfstorm): Don't know why it hangs. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(Actor::new); List objectIds = new ArrayList<>(); for (int i = 0; i < 10; i++) { From 9fc1e2424aba1769f7adb488d6a72f2b73d55c56 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Thu, 22 Aug 2019 21:09:51 +0800 Subject: [PATCH 03/24] Support pass large object by value to a direct call actor. --- .../org/ray/runtime/AbstractRayRuntime.java | 14 ++++++++-- .../org/ray/runtime/actor/NativeRayActor.java | 6 ++++ .../ray/runtime/task/ArgumentsBuilder.java | 17 +++++++++-- .../src/main/java/org/ray/api/TestUtils.java | 5 ++++ .../main/java/org/ray/api/test/ActorTest.java | 28 +++++++++++++++++++ .../java/org/ray/api/test/RayCallTest.java | 6 +--- .../org_ray_runtime_actor_NativeRayActor.cc | 10 +++++++ .../org_ray_runtime_actor_NativeRayActor.h | 8 ++++++ .../transport/direct_actor_transport.cc | 4 +++ 9 files changed, 88 insertions(+), 10 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 55d5fc5f9267e..6d442c4785344 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -15,6 +15,7 @@ import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtimecontext.RuntimeContext; +import org.ray.runtime.actor.NativeRayActor; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.RuntimeContextImpl; import org.ray.runtime.context.WorkerContext; @@ -166,7 +167,7 @@ public RayPyActor createPyActor(String moduleName, String className, Object[] ar private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args, CallOptions options) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectActorCall*/false); List returnIds = taskSubmitter.submitTask(functionDescriptor, functionArgs, 1, options); return new RayObjectImpl(returnIds.get(0)); @@ -175,7 +176,7 @@ private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, private RayObject callActorFunction(RayActor rayActor, FunctionDescriptor functionDescriptor, Object[] args) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, isDirectActorCall(rayActor)); List returnIds = taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, 1, null); return new RayObjectImpl(returnIds.get(0)); @@ -184,7 +185,7 @@ private RayObject callActorFunction(RayActor rayActor, private RayActor createActorImpl(FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectActorCall*/false); if (functionDescriptor.getLanguage() != Language.JAVA && options != null) { Preconditions.checkState(StringUtil.isNullOrEmpty(options.jvmOptions)); } @@ -194,6 +195,13 @@ private RayActor createActorImpl(FunctionDescriptor functionDescriptor, return actor; } + private boolean isDirectActorCall(RayActor rayActor) { + if (rayActor instanceof NativeRayActor) { + return ((NativeRayActor) rayActor).isDirectCall(); + } + return false; + } + public WorkerContext getWorkerContext() { return workerContext; } diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java index ecdf030535e18..cbe13c092a411 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -51,6 +51,10 @@ public Language getLanguage() { return Language.forNumber(nativeGetLanguage(nativeActorHandle)); } + public boolean isDirectCall() { + return nativeIsDirectCall(nativeActorHandle); + } + @Override public String getModuleName() { Preconditions.checkState(getLanguage() == Language.PYTHON); @@ -90,6 +94,8 @@ protected void finalize() { private static native int nativeGetLanguage(long nativeActorHandle); + private static native boolean nativeIsDirectCall(long nativeActorHandle); + private static native List nativeGetActorCreationTaskFunctionDescriptor( long nativeActorHandle); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 110c178f7a10a..175a99d30a783 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -24,7 +24,8 @@ public class ArgumentsBuilder { /** * Convert real function arguments to task spec arguments. */ - public static List wrap(Object[] args, boolean crossLanguage) { + public static List wrap(Object[] args, boolean crossLanguage, + boolean isDirectActorCall) { List ret = new ArrayList<>(); for (Object arg : args) { ObjectId id = null; @@ -32,15 +33,20 @@ public static List wrap(Object[] args, boolean crossLanguage) { if (arg == null) { data = Serializer.encode(null); } else if (arg instanceof RayObject) { + throwExceptionIfIsDirectActorCall(isDirectActorCall, + "Passing RayObject to a direct call actor is not supported."); id = ((RayObject) arg).getId(); } else if (arg instanceof byte[] && crossLanguage) { + // TODO (kfstorm): This could be supported once we supported passing by value with metadata. + throwExceptionIfIsDirectActorCall(isDirectActorCall, + "Passing raw bytes to a direct call actor is not supported."); // If the argument is a byte array and will be used by a different language, // do not inline this argument. Because the other language doesn't know how // to deserialize it. id = Ray.put(arg).getId(); } else { byte[] serialized = Serializer.encode(arg); - if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { + if (!isDirectActorCall && serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { id = ((AbstractRayRuntime) Ray.internal()).getObjectStore() .put(new NativeRayObject(serialized, null)); } else { @@ -56,6 +62,13 @@ public static List wrap(Object[] args, boolean crossLanguage) { return ret; } + private static void throwExceptionIfIsDirectActorCall(boolean isDirectActorCall, String message) { + if (isDirectActorCall) { + throw new IllegalArgumentException( + message != null ? message : "Direct actor call only supports by-value arguments."); + } + } + /** * Convert list of NativeRayObject to real function arguments. */ diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 4e29f36dd3df6..8a68833915264 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,5 +1,6 @@ package org.ray.api; +import java.io.Serializable; import java.util.function.Supplier; import org.ray.api.annotation.RayRemote; import org.ray.api.options.ActorCreationOptions; @@ -10,6 +11,10 @@ public class TestUtils { + public static class LargeObject implements Serializable { + public byte[] data = new byte[1024 * 1024]; + } + private static final int WAIT_INTERVAL_MS = 5; public static void skipTestUnderSingleProcess() { diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 7ee4bb12dade0..02abd738bc4ba 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -8,6 +8,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; +import org.ray.api.TestUtils.LargeObject; import org.ray.api.annotation.RayRemote; import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; @@ -36,6 +37,25 @@ public int increase(int delta) { value += delta; return value; } + + public int accessLargeObject(LargeObject largeObject) { + value += largeObject.data.length; + return value; + } + } + + @RayRemote + public static class Caller { + + private final RayActor counter; + + public Caller(RayActor counter) { + this.counter = counter; + } + + public int call() { + return Ray.call(Counter::increase, counter, 1).get(); + } } @Test @@ -48,6 +68,14 @@ public void testCreateAndCallActor() { Assert.assertEquals(Integer.valueOf(11), Ray.call(Counter::increase, actor, 10).get()); } + @Test + public void testCallActorWithLargeObject() { + RayActor actor = Ray.createActor(Counter::new, 1); + LargeObject largeObject = new LargeObject(); + Assert.assertEquals(Integer.valueOf(largeObject.data.length + 1), + Ray.call(Counter::accessLargeObject, actor, largeObject).get()); + } + @RayRemote public static Counter factory(int initValue) { return new Counter(initValue); diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index c97e3fe915244..8f8318f932b55 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -2,10 +2,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.io.Serializable; import java.util.List; import java.util.Map; import org.ray.api.Ray; +import org.ray.api.TestUtils.LargeObject; import org.ray.api.annotation.RayRemote; import org.testng.Assert; import org.testng.annotations.Test; @@ -65,10 +65,6 @@ private static Map testMap(Map val) { return val; } - public static class LargeObject implements Serializable { - private byte[] data = new byte[1024 * 1024]; - } - @RayRemote private static LargeObject testLargeObject(LargeObject largeObject) { return largeObject; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index a63e7efa00162..f632f280f75b4 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -57,6 +57,16 @@ JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLangua return (jint)GetActorHandle(nativeActorHandle).ActorLanguage(); } +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeIsDirectCall + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCall( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return GetActorHandle(nativeActorHandle).IsDirectCallActor(); +} + /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetActorCreationTaskFunctionDescriptor diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h index 4de114c7a8b4d..92363773c8d70 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -40,6 +40,14 @@ Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorHandleId(JNIEnv *, jclas JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jlong); +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeIsDirectCall + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCall(JNIEnv *, jclass, jlong); + /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetActorCreationTaskFunctionDescriptor diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 2f5733d3cf08f..65916da62764b 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -28,6 +28,7 @@ CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( const TaskSpecification &task_spec) { + RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); if (HasByReferenceArgs(task_spec)) { return Status::Invalid("direct actor call only supports by-value arguments"); } @@ -52,6 +53,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // to have a timeout to mark it as invalid if it doesn't show up in the // specified time. pending_requests_[actor_id].emplace_back(std::move(request)); + RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created."; return Status::OK(); } else if (iter->second.state_ == ActorTableData::ALIVE) { // Actor is alive, submit the request. @@ -125,6 +127,7 @@ Status CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &clie const rpc::PushTaskRequest &request, const TaskID &task_id, int num_returns) { + RAY_LOG(DEBUG) << "Push task " << task_id; auto status = client.PushTask( request, [this, task_id, num_returns](Status status, const rpc::PushTaskReply &reply) { @@ -188,6 +191,7 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { const TaskSpecification task_spec(request.task_spec()); + RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); if (HasByReferenceArgs(task_spec)) { send_reply_callback( Status::Invalid("direct actor call only supports by value arguments"), nullptr, From 0c474b9ebf7807a572583887c8919e9939a57a79 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Fri, 23 Aug 2019 20:06:58 +0800 Subject: [PATCH 04/24] Support actor checkpointing for direct call --- .../context/LocalModeWorkerContext.java | 5 +++++ .../runtime/context/NativeWorkerContext.java | 7 +++++++ .../ray/runtime/context/WorkerContext.java | 5 +++++ .../runtime/raylet/LocalModeRayletClient.java | 3 ++- .../runtime/raylet/NativeRayletClient.java | 8 ++++--- .../org/ray/runtime/raylet/RayletClient.java | 2 +- .../org/ray/runtime/task/TaskExecutor.java | 7 ++++--- .../ray/api/test/ActorReconstructionTest.java | 2 -- python/ray/_raylet.pyx | 2 +- python/ray/includes/libraylet.pxd | 1 + python/ray/includes/task.pxd | 3 ++- python/ray/includes/task.pxi | 1 + src/ray/common/task/task_spec.cc | 8 ++++++- src/ray/common/task/task_spec.h | 2 ++ src/ray/common/task/task_util.h | 13 +++++++----- src/ray/core_worker/context.cc | 5 +++++ src/ray/core_worker/context.h | 2 ++ ...ray_runtime_context_NativeWorkerContext.cc | 11 ++++++++++ ..._ray_runtime_context_NativeWorkerContext.h | 9 ++++++++ ...g_ray_runtime_raylet_NativeRayletClient.cc | 7 ++++--- ...rg_ray_runtime_raylet_NativeRayletClient.h | 5 +++-- src/ray/core_worker/task_interface.cc | 3 ++- src/ray/core_worker/test/core_worker_test.cc | 3 ++- src/ray/protobuf/common.proto | 2 ++ src/ray/raylet/actor_registration.cc | 21 +++++++++++-------- src/ray/raylet/actor_registration.h | 5 +++-- src/ray/raylet/format/node_manager.fbs | 2 ++ src/ray/raylet/node_manager.cc | 19 +++++++++++------ src/ray/raylet/raylet_client.cc | 5 +++-- src/ray/raylet/raylet_client.h | 3 ++- 30 files changed, 126 insertions(+), 45 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java index 1f05c3d5989b7..06b1d578b8218 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java @@ -41,6 +41,11 @@ public ActorId getCurrentActorId() { return LocalModeTaskSubmitter.getActorId(taskSpec); } + @Override + public boolean getIsDirectCall() { + return false; + } + @Override public ClassLoader getCurrentClassLoader() { return null; diff --git a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java index b42a7b23411db..62188a952ca97 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java @@ -38,6 +38,11 @@ public ActorId getCurrentActorId() { return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer)); } + @Override + public boolean getIsDirectCall() { + return nativeGetIsDirectCall(nativeCoreWorkerPointer); + } + @Override public ClassLoader getCurrentClassLoader() { return currentClassLoader; @@ -69,4 +74,6 @@ public TaskId getCurrentTaskId() { private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer); private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer); + + private static native boolean nativeGetIsDirectCall(long nativeCoreWorkerPointer); } diff --git a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java index 4a526c85ecbfb..23894e0548db4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java @@ -26,6 +26,11 @@ public interface WorkerContext { */ ActorId getCurrentActorId(); + /** + * Whether the current task is a direct call task. + */ + boolean getIsDirectCall(); + /** * The class loader that is associated with the current job. It's used for locating classes when * dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}. diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java index 9d43244c35db5..a9976893095a4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java @@ -10,10 +10,11 @@ * Raylet client for local mode. */ public class LocalModeRayletClient implements RayletClient { + private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class); @Override - public UniqueId prepareCheckpoint(ActorId actorId) { + public UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall) { throw new NotImplementedException("Not implemented."); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java index ed5f10f128ce5..e423aa8301306 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java @@ -19,8 +19,9 @@ public NativeRayletClient(long nativeCoreWorkerPointer) { } @Override - public UniqueId prepareCheckpoint(ActorId actorId) { - return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes())); + public UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall) { + return new UniqueId( + nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(), isDirectCall)); } @Override @@ -47,7 +48,8 @@ public void setResource(String resourceName, double capacity, UniqueId nodeId) { /// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc /// 6) popd - private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); + private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId, + boolean isDirectCall); private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, byte[] checkpointId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 144187b6b83a4..12e2bc712bd99 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -8,7 +8,7 @@ */ public interface RayletClient { - UniqueId prepareCheckpoint(ActorId actorId); + UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall); void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 2f595bd2e1d7a..2997753f8e7d6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -98,7 +98,8 @@ protected List execute(List rayFunctionInfo, if (taskType != TaskType.ACTOR_CREATION_TASK) { if (taskType == TaskType.ACTOR_TASK) { // TODO (kfstorm): handle checkpoint in core worker. - maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId()); + maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId(), + runtime.getWorkerContext().getIsDirectCall()); } returnObjects.add(runtime.getObjectStore().serialize(result)); } else { @@ -128,7 +129,7 @@ private JavaFunctionDescriptor parseFunctionDescriptor(List rayFunctionI rayFunctionInfo.get(2)); } - private void maybeSaveCheckpoint(Object actor, ActorId actorId) { + private void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { if (!(actor instanceof Checkpointable)) { return; } @@ -144,7 +145,7 @@ private void maybeSaveCheckpoint(Object actor, ActorId actorId) { } numTasksSinceLastCheckpoint = 0; lastCheckpointTimestamp = System.currentTimeMillis(); - UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId); + UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId, isDirectCall); checkpointIds.add(checkpointId); if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 1d1d730d96c98..94e7a5cd8b44a 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -130,8 +130,6 @@ public void checkpointExpired(ActorId actorId, UniqueId checkpointId) { @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); - // Actor checkpointing is not implemented in direct actor call yet. - TestUtils.skipTestIfDirectActorCallEnabled(); ActorCreationOptions options = new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(CheckpointableCounter::new, options); diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a1e9387fe86fe..47b90942b2dc0 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -357,7 +357,7 @@ cdef class RayletClient: # the GIL so other Python threads can run. with nogil: check_status(self.client.get().PrepareActorCheckpoint( - c_actor_id, checkpoint_id)) + c_actor_id, False, checkpoint_id)) return ActorCheckpointID(checkpoint_id.Binary()) def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 45746da2fdf43..e5dcf6efe1689 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -68,6 +68,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids, c_bool local_only, c_bool delete_creating_tasks) CRayStatus PrepareActorCheckpoint(const CActorID &actor_id, + c_bool is_direct_call, CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 00b45d02baf87..40baabe9f0b8c 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -97,7 +97,8 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil: const CObjectID &actor_creation_dummy_object_id, const CObjectID &previous_actor_task_dummy_object_id, uint64_t actor_counter, - const c_vector[CActorHandleID] &new_handle_ids); + const c_vector[CActorHandleID] &new_handle_ids, + c_bool is_direct_call); RpcTaskSpec GetMessage() diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index f1290ac7d7273..d052a791026ee 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -92,6 +92,7 @@ cdef class TaskSpec: previous_actor_task_dummy_object_id.native(), actor_counter, c_new_actor_handles, + False, ) else: # Normal task. diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 1f04d394329e3..d6172810c82f7 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -146,6 +146,11 @@ std::vector TaskSpecification::NewActorHandles() const { message_->actor_task_spec().new_actor_handles()); } +bool TaskSpecification::IsDirectCall() const { + RAY_CHECK(IsActorTask()); + return message_->actor_task_spec().is_direct_call(); +} + std::string TaskSpecification::DebugString() const { std::ostringstream stream; stream << "Type=" << TaskType_Name(message_->type()) @@ -174,7 +179,8 @@ std::string TaskSpecification::DebugString() const { // Print actor task spec. stream << ", actor_task_spec={actor_id=" << ActorId() << ", actor_handle_id=" << ActorHandleId() - << ", actor_counter=" << ActorCounter() << "}"; + << ", actor_counter=" << ActorCounter() + << ", is_direct_call=" << IsDirectCall() << "}"; } return stream.str(); diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index d1fae90ac9e5b..e8704dc47bcf4 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -127,6 +127,8 @@ class TaskSpecification : public MessageWrapper { std::vector NewActorHandles() const; + bool IsDirectCall() const; + ObjectID ActorDummyObject() const; std::string DebugString() const; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 2bc635cc0313e..a343e8da97aec 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -93,11 +93,13 @@ class TaskSpecBuilder { /// See `common.proto` for meaning of the arguments. /// /// \return Reference to the builder object itself. - TaskSpecBuilder &SetActorTaskSpec( - const ActorID &actor_id, const ActorHandleID &actor_handle_id, - const ObjectID &actor_creation_dummy_object_id, - const ObjectID &previous_actor_task_dummy_object_id, uint64_t actor_counter, - const std::vector &new_handle_ids = {}) { + TaskSpecBuilder &SetActorTaskSpec(const ActorID &actor_id, + const ActorHandleID &actor_handle_id, + const ObjectID &actor_creation_dummy_object_id, + const ObjectID &previous_actor_task_dummy_object_id, + uint64_t actor_counter, + const std::vector &new_handle_ids = {}, + bool is_direct_call = false) { message_->set_type(TaskType::ACTOR_TASK); auto actor_spec = message_->mutable_actor_task_spec(); actor_spec->set_actor_id(actor_id.Binary()); @@ -110,6 +112,7 @@ class TaskSpecBuilder { for (const auto &id : new_handle_ids) { actor_spec->add_new_actor_handles(id.Binary()); } + actor_spec->set_is_direct_call(is_direct_call); return *this; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 896f8a70e91b1..ac63797f9d048 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -100,6 +100,11 @@ const ActorID &WorkerContext::GetCurrentActorID() const { return GetThreadContext().GetCurrentActorID(); } +bool WorkerContext::IsDirectCall() const { + std::shared_ptr task = GetThreadContext().GetCurrentTask(); + return task && task->IsActorTask() && task->IsDirectCall(); +} + WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { thread_context_ = std::unique_ptr(new WorkerThreadContext()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 77e9e28141450..54225cb109783 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -26,6 +26,8 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; + bool IsDirectCall() const; + int GetNextTaskIndex(); int GetNextPutIndex(); diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc index b7e7910446cd0..cc64478e75d15 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc @@ -78,6 +78,17 @@ Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId( return IdToJavaByteBuffer(env, actor_id); } +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetIsDirectCall + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetIsDirectCall( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + return GetWorkerContextFromPointer(nativeCoreWorkerPointer).IsDirectCall(); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h index fe3725484a0ea..b18060f0d772a 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h @@ -52,6 +52,15 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass, jlong); +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetIsDirectCall + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetIsDirectCall(JNIEnv *, jclass, + jlong); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc index e84e4c51e1498..ddb907126ce99 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc @@ -19,15 +19,16 @@ using ray::ClientID; /* * Class: org_ray_runtime_raylet_NativeRayletClient * Method: nativePrepareCheckpoint - * Signature: (J[B)[B + * Signature: (J[BZ)[B */ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId) { + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, + jboolean isDirectCall) { const auto actor_id = JavaByteArrayToId(env, actorId); ActorCheckpointID checkpoint_id; auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) - .PrepareActorCheckpoint(actor_id, checkpoint_id); + .PrepareActorCheckpoint(actor_id, isDirectCall, checkpoint_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); jbyteArray result = env->NewByteArray(checkpoint_id.Size()); env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h index 0b54300def8e4..9f7f67ef9f282 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h @@ -10,11 +10,12 @@ extern "C" { /* * Class: org_ray_runtime_raylet_NativeRayletClient * Method: nativePrepareCheckpoint - * Signature: (J[B)[B + * Signature: (J[BZ)[B */ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(JNIEnv *, jclass, - jlong, jbyteArray); + jlong, jbyteArray, + jboolean); /* * Class: org_ray_runtime_raylet_NativeRayletClient diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index de880f83eeff4..be588178164ce 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -217,7 +217,8 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), - actor_handle.IncreaseTaskCounter(), actor_handle.NewActorHandles()); + actor_handle.IncreaseTaskCounter(), actor_handle.NewActorHandles(), + actor_handle.IsDirectCallActor()); // Manipulate actor handle state. auto actor_cursor = (*return_ids).back(); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 99a7f5ca83b0d..8b976ed2282bf 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -652,7 +652,8 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { builder.SetActorTaskSpec( actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, - /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), 0, {}); + /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), 0, {}, + actor_handle.IsDirectCallActor()); const auto &task_spec = builder.Build(); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index badcbea5b1ed2..df95593f7c6e2 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -106,6 +106,8 @@ message ActorTaskSpec { repeated bytes new_actor_handles = 6; // The dummy object ID of the previous actor task. bytes previous_actor_task_dummy_object_id = 7; + // Whether direct actor call is used. + bool is_direct_call = 8; } // The task execution specification encapsulates all mutable information about diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 7381d8d13a5a5..d2d956159163d 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -98,16 +98,19 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } std::shared_ptr ActorRegistration::GenerateCheckpointData( - const ActorID &actor_id, const Task &task) { - const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); - const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); - // Make a copy of the actor registration, and extend its frontier to include - // the most recent task. - // Note(hchen): this is needed because this method is called before - // `FinishAssignedTask`, which will be called when the worker tries to fetch - // the next task. + const ActorID &actor_id, const Task *task) { + // Make a copy of the actor registration ActorRegistration copy = *this; - copy.ExtendFrontier(actor_handle_id, dummy_object); + if (task) { + const auto actor_handle_id = task->GetTaskSpecification().ActorHandleId(); + const auto dummy_object = task->GetTaskSpecification().ActorDummyObject(); + // Extend its frontier to include + // the most recent task. + // Note(hchen): this is needed because this method is called before + // `FinishAssignedTask`, which will be called when the worker tries to fetch + // the next task. + copy.ExtendFrontier(actor_handle_id, dummy_object); + } // Use actor's current state to generate checkpoint data. auto checkpoint_data = std::make_shared(); diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 67bd394e85a93..8aa40253bab06 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -133,10 +133,11 @@ class ActorRegistration { /// Generate checkpoint data based on actor's current state. /// /// \param actor_id ID of this actor. - /// \param task The task that just finished on the actor. + /// \param task The task that just finished on the actor. (nullptr when it's direct + /// call.) /// \return A shared pointer to the generated checkpoint data. std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + const Task *task); private: /// Information from the global actor table about this actor, including the diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 705a9fdba9dd3..60e20a5aea6e9 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -219,6 +219,8 @@ table FreeObjectsRequest { table PrepareActorCheckpointRequest { // ID of the actor. actor_id: string; + // Whether direct actor call is used. + is_direct_call: bool; } table PrepareActorCheckpointReply { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 53b84d77da4c7..7a66e31b68ef5 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1155,6 +1155,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( auto message = flatbuffers::GetRoot(message_data); ActorID actor_id = from_flatbuf(*message->actor_id()); + bool is_direct_call = message->is_direct_call(); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); @@ -1162,13 +1163,19 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker && worker->GetActorId() == actor_id); - // Find the task that is running on this actor. - const auto task_id = worker->GetAssignedTaskId(); - const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); - // Generate checkpoint id and data. ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom(); - auto checkpoint_data = - actor_entry->second.GenerateCheckpointData(actor_entry->first, task); + std::shared_ptr checkpoint_data; + if (is_direct_call) { + checkpoint_data = + actor_entry->second.GenerateCheckpointData(actor_entry->first, nullptr); + } else { + // Find the task that is running on this actor. + const auto task_id = worker->GetAssignedTaskId(); + const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); + // Generate checkpoint data. + checkpoint_data = + actor_entry->second.GenerateCheckpointData(actor_entry->first, &task); + } // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 1c8871bf0bd08..be1db5297023b 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -354,10 +354,11 @@ ray::Status RayletClient::FreeObjects(const std::vector &object_i } ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, + bool is_direct_call, ActorCheckpointID &checkpoint_id) { flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); + auto message = ray::protocol::CreatePrepareActorCheckpointRequest( + fbb, to_flatbuf(fbb, actor_id), is_direct_call); fbb.Finish(message); std::unique_ptr reply; diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 235ba9cfb890b..793e2ea5a57ab 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -155,9 +155,10 @@ class RayletClient { /// Request raylet backend to prepare a checkpoint for an actor. /// /// \param actor_id ID of the actor. + /// \param is_direct_call Whether direct actor call is used. /// \param checkpoint_id ID of the new checkpoint (output parameter). /// \return ray::Status. - ray::Status PrepareActorCheckpoint(const ActorID &actor_id, + ray::Status PrepareActorCheckpoint(const ActorID &actor_id, bool is_direct_call, ActorCheckpointID &checkpoint_id); /// Notify raylet backend that an actor was resumed from a checkpoint. From 7a5801bef4bff259226b067dbe540c15744bf9d7 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Fri, 23 Aug 2019 22:43:52 +0800 Subject: [PATCH 05/24] Fix test case `testActorProcessDying` --- .../java/org/ray/api/test/FailureTest.java | 3 -- .../transport/direct_actor_transport.cc | 35 +++++++++++++++---- .../transport/direct_actor_transport.h | 6 +++- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 2f23a58191932..b47b010ae25a8 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -105,9 +105,6 @@ public void testWorkerProcessDying() { @Test public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); - // If direct actor call is enabled, we can get a RayActorException only if the actor - // is already dead before submitting the task. - TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(BadActor::new, false); try { Ray.call(BadActor::badMethod2, actor).get(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 5f9e670318f01..62f0489505496 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -64,13 +64,14 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // Submit request. auto &client = rpc_clients_[actor_id]; - PushTask(*client, *request, task_id, num_returns); + PushTask(*client, *request, actor_id, task_id, num_returns); return Status::OK(); } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); - return Status::IOError("Actor is dead."); + // Return OK here so that we can get the error from store with get operation. + return Status::OK(); } } @@ -94,6 +95,19 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { } else { // Remove rpc client if it's dead or being reconstructed. rpc_clients_.erase(actor_id); + + // For tasks that have been sent and are waiting for replies, treat them + // as failed when the destination actor is dead or reconstructing. + auto iter = waiting_reply_tasks_.find(actor_id); + if (iter != waiting_reply_tasks_.end()) { + for (const auto &entry : iter->second) { + const auto &task_id = entry.first; + const auto num_returns = entry.second; + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); + } + waiting_reply_tasks_.erase(actor_id); + } + // If this actor is permanently dead and there are pending requests, treat // the pending tasks as failed. if (actor_data.state() == ActorTableData::DEAD && @@ -127,7 +141,8 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( auto &requests = pending_requests_[actor_id]; while (!requests.empty()) { const auto &request = *requests.front(); - PushTask(*client, request, TaskID::FromBinary(request.task_spec().task_id()), + PushTask(*client, request, actor_id, + TaskID::FromBinary(request.task_spec().task_id()), request.task_spec().num_returns()); requests.pop_front(); } @@ -135,12 +150,18 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client, const rpc::PushTaskRequest &request, + const ActorID &actor_id, const TaskID &task_id, int num_returns) { RAY_LOG(DEBUG) << "Push task " << task_id; - auto status = client.PushTask( - request, - [this, task_id, num_returns](Status status, const rpc::PushTaskReply &reply) { + waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); + auto status = + client.PushTask(request, [this, actor_id, task_id, num_returns]( + Status status, const rpc::PushTaskReply &reply) { + { + std::unique_lock guard(mutex_); + waiting_reply_tasks_[actor_id].erase(task_id); + } if (!status.ok()) { TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); return; @@ -173,6 +194,8 @@ void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) { + RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id + << ", error_type: " << ErrorType_Name(error_type); for (int i = 0; i < num_returns; i++) { const auto object_id = ObjectID::ForTaskReturn( task_id, /*index=*/i + 1, diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 056615ff0da0b..0f14969a3306b 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -48,11 +48,12 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// /// \param[in] client The RPC client to send tasks to an actor. /// \param[in] request The request to send. + /// \param[in] actor_id Actor ID. /// \param[in] task_id The ID of a task. /// \param[in] num_returns Number of return objects. /// \return Void. void PushTask(rpc::DirectActorClient &client, const rpc::PushTaskRequest &request, - const TaskID &task_id, int num_returns); + const ActorID &actor_id, const TaskID &task_id, int num_returns); /// Treat a task as failed. /// @@ -110,6 +111,9 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { std::unordered_map>> pending_requests_; + /// Map from actor id to the tasks that are waiting for reply. + std::unordered_map> waiting_reply_tasks_; + /// The store provider. std::unique_ptr store_provider_; From 748938d36cd524c162a91544adb81f9ffbe7b2a4 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 24 Aug 2019 02:02:04 +0800 Subject: [PATCH 06/24] Update ActorTest.java --- .../src/main/java/org/ray/api/test/ActorTest.java | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 850939af51905..979c19147fbbe 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -48,20 +48,6 @@ public int accessLargeObject(LargeObject largeObject) { } } - @RayRemote - public static class Caller { - - private final RayActor counter; - - public Caller(RayActor counter) { - this.counter = counter; - } - - public int call() { - return Ray.call(Counter::increase, counter, 1).get(); - } - } - @Test public void testCreateAndCallActor() { // Test creating an actor from a constructor From fe0ea89d6a931167ac869a5fc9ab34509ff7d84a Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 24 Aug 2019 11:12:58 +0800 Subject: [PATCH 07/24] Fix duplicated put in memory store provider --- .../core_worker/store_provider/memory_store_provider.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc index e83a5a30713d8..4be054c920033 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ b/src/ray/core_worker/store_provider/memory_store_provider.cc @@ -18,7 +18,13 @@ CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider( Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object, const ObjectID &object_id) { - return store_->Put(object_id, object); + Status status = store_->Put(object_id, object); + if (status.IsKeyError()) { + RAY_LOG(WARNING) << "Trying to put an object that already existed in memory: " + << object_id << "."; + return Status::OK(); + } + return status; } Status CoreWorkerMemoryStoreProvider::Get( From 82ca22bcec0c4db9262f1cba8f8ee7d69b794125 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Mon, 26 Aug 2019 20:45:09 +0800 Subject: [PATCH 08/24] Direct actor call subscribe to individual actor updates --- .../main/java/org/ray/api/test/StressTest.java | 2 -- .../transport/direct_actor_transport.cc | 13 ++++++++++--- .../transport/direct_actor_transport.h | 18 +++++++----------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index a1dc5116e055d..e2efecbf222e1 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -75,8 +75,6 @@ public int ping(int n) { @Test public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); - // TODO (kfstorm): Don't know why it hangs. - TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(Actor::new); List objectIds = new ArrayList<>(); for (int i = 0; i < 10; i++) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 62f0489505496..ef8f409dd0fdf 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -22,7 +22,6 @@ CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( gcs_client_(gcs_client), client_call_manager_(io_service), store_provider_(std::move(store_provider)) { - RAY_CHECK_OK(SubscribeActorUpdates()); } Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( @@ -42,6 +41,12 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); std::unique_lock guard(mutex_); + + if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) { + RAY_CHECK_OK(SubscribeActorUpdates(actor_id)); + subscribed_actors_.insert(actor_id); + } + auto iter = actor_states_.find(actor_id); if (iter == actor_states_.end() || iter->second.state_ == ActorTableData::RECONSTRUCTING) { @@ -75,7 +80,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( } } -Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { +Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates( + const ActorID &actor_id) { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](const ActorID &actor_id, const ActorTableData &actor_data) { @@ -127,7 +133,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { << ", port: " << actor_data.port(); }; - return gcs_client_.Actors().AsyncSubscribe(actor_notification_callback, nullptr); + return gcs_client_.Actors().AsyncSubscribe(actor_id, actor_notification_callback, + nullptr); } void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 0f14969a3306b..ec5b2d77a260c 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -2,6 +2,7 @@ #define RAY_CORE_WORKER_DIRECT_ACTOR_TRANSPORT_H #include +#include #include "ray/core_worker/object_interface.h" #include "ray/core_worker/transport/transport.h" @@ -39,8 +40,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { Status SubmitTask(const TaskSpecification &task_spec) override; private: - /// Subscribe to all actor updates. - Status SubscribeActorUpdates(); + /// Subscribe to updates of an actor. + Status SubscribeActorUpdates(const ActorID &actor_id); /// Push a task to a remote actor via the given client. /// Note, this function doesn't return any error status code. If an error occurs while @@ -93,18 +94,10 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// Mutex to proect the various maps below. mutable std::mutex mutex_; - /// Map from actor id to actor state. This currently includes all actors in the system. - /// - /// TODO(zhijunfu): this map currently keeps track of all the actors in the system, - /// like `actor_registry_` in raylet. Later after new GCS client interface supports - /// subscribing updates for a specific actor, this will be updated to only include - /// entries for actors that the transport submits tasks to. + /// Map from actor id to actor state. This only includes actors that we send tasks to. std::unordered_map actor_states_; /// Map from actor id to rpc client. This only includes actors that we send tasks to. - /// - /// TODO(zhijunfu): this will be moved into `actor_states_` later when we can - /// subscribe updates for a specific actor. std::unordered_map> rpc_clients_; /// Map from actor id to the actor's pending requests. @@ -114,6 +107,9 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// Map from actor id to the tasks that are waiting for reply. std::unordered_map> waiting_reply_tasks_; + /// The set of actors which are subscribed for further updates. + std::unordered_set subscribed_actors_; + /// The store provider. std::unique_ptr store_provider_; From c0da33811d04344d839ba47293e6130e09d24774 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Mon, 26 Aug 2019 21:26:31 +0800 Subject: [PATCH 09/24] Fix IsActorAlive --- src/ray/core_worker/transport/direct_actor_transport.cc | 8 +++++++- src/ray/core_worker/transport/direct_actor_transport.h | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index ef8f409dd0fdf..63af6be552948 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -214,8 +214,14 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( } } -bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const { +bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) { std::unique_lock guard(mutex_); + + if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) { + RAY_CHECK_OK(SubscribeActorUpdates(actor_id)); + subscribed_actors_.insert(actor_id); + } + auto iter = actor_states_.find(actor_id); return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE); } diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index ec5b2d77a260c..c280130c0e66e 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -80,7 +80,7 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// /// \param[in] actor_id The actor ID. /// \return Whether this actor is alive. - bool IsActorAlive(const ActorID &actor_id) const; + bool IsActorAlive(const ActorID &actor_id); /// The IO event loop. boost::asio::io_service &io_service_; From 779c49c650340fdbd60de626401999a44a0012e9 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 27 Aug 2019 15:51:50 +0800 Subject: [PATCH 10/24] Skip testActorProcessDying because it hangs --- java/test/src/main/java/org/ray/api/test/FailureTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index b47b010ae25a8..96aa9c6c07f4c 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -105,6 +105,8 @@ public void testWorkerProcessDying() { @Test public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); + // This test case hangs if the worker to worker connection is implemented with grpc. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(BadActor::new, false); try { Ray.call(BadActor::badMethod2, actor).get(); From a1600bc5fff7253e692db7028d2e91e765cc669c Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Wed, 28 Aug 2019 13:59:41 +0800 Subject: [PATCH 11/24] Add todos for skipped tests. --- .../java/org/ray/api/options/ActorCreationOptions.java | 10 ++++++---- .../java/org/ray/api/test/ActorReconstructionTest.java | 2 +- .../test/src/main/java/org/ray/api/test/ActorTest.java | 1 + .../org/ray/api/test/CrossLanguageInvocationTest.java | 1 + .../src/main/java/org/ray/api/test/FailureTest.java | 1 + 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 69918edc3e34c..949c16ad1d22c 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -47,10 +47,12 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } - public Builder setIsDirectCall(boolean isDirectCall) { - this.isDirectCall = isDirectCall; - return this; - } + // Since direct call is not fully supported yet, users are not allowed to set the option to true. + // TODO (kfstorm): uncomment when direct call is ready. +// public Builder setIsDirectCall(boolean isDirectCall) { +// this.isDirectCall = isDirectCall; +// return this; +// } public Builder setJvmOptions(String jvmOptions) { this.jvmOptions = jvmOptions; diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 94e7a5cd8b44a..5f4677c77fd61 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -47,7 +47,7 @@ public int getPid() { @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); - // No lineage cache when direct actor call is enabled. + // By design. No lineage cache when direct actor call is enabled. TestUtils.skipTestIfDirectActorCallEnabled(); ActorCreationOptions options = new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 979c19147fbbe..cfe5382530f45 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -124,6 +124,7 @@ public void testForkingActorHandle() { public void testUnreconstructableActorObject() throws InterruptedException { TestUtils.skipTestUnderSingleProcess(); // The UnreconstructableException is created by raylet. + // TODO (kfstorm): This should be supported by direct actor call. TestUtils.skipTestIfDirectActorCallEnabled(); RayActor counter = Ray.createActor(Counter::new, 100); // Call an actor method. diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index f69c520fec157..bfc996349ce00 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -50,6 +50,7 @@ public void testCallingPythonFunction() { public void testCallingPythonActor() { // Direct actor call only allows passing arguments as values. // However, bytes arguments are passed from Java to Python as references. + // TODO (kfstorm): This should be supported once passing by value with metadata is allowed. TestUtils.skipTestIfDirectActorCallEnabled(); RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 96aa9c6c07f4c..99a33cb6656d0 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -106,6 +106,7 @@ public void testWorkerProcessDying() { public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); // This test case hangs if the worker to worker connection is implemented with grpc. + // TODO (kfstorm): Should be fixed. TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(BadActor::new, false); try { From 1b7837f4c2229060415c58cf1d0faba3b239dedc Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Fri, 6 Sep 2019 21:55:20 +0800 Subject: [PATCH 12/24] Address comments --- .../org/ray/runtime/AbstractRayRuntime.java | 16 --- .../java/org/ray/runtime/RayDevRuntime.java | 16 ++- .../org/ray/runtime/RayNativeRuntime.java | 18 +++- .../runtime/raylet/LocalModeRayletClient.java | 30 ------ .../runtime/raylet/NativeRayletClient.java | 59 ---------- .../org/ray/runtime/raylet/RayletClient.java | 16 --- .../ray/runtime/task/ArgumentsBuilder.java | 19 ++-- .../runtime/task/LocalModeTaskExecutor.java | 22 ++++ .../runtime/task/LocalModeTaskSubmitter.java | 4 +- .../ray/runtime/task/NativeTaskExecutor.java | 102 ++++++++++++++++++ .../org/ray/runtime/task/TaskExecutor.java | 84 +-------------- .../java/org_ray_runtime_RayNativeRuntime.cc | 19 ++++ .../java/org_ray_runtime_RayNativeRuntime.h | 8 ++ ...g_ray_runtime_raylet_NativeRayletClient.cc | 75 ------------- ...rg_ray_runtime_raylet_NativeRayletClient.h | 40 ------- ...org_ray_runtime_task_NativeTaskExecutor.cc | 55 ++++++++++ .../org_ray_runtime_task_NativeTaskExecutor.h | 33 ++++++ .../transport/direct_actor_transport.cc | 2 +- 18 files changed, 281 insertions(+), 337 deletions(-) delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java delete mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc delete mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 39462bd609bca..ed1c9bcc6ee9c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -12,7 +12,6 @@ import org.ray.api.function.RayFunc; import org.ray.api.function.RayFuncVoid; import org.ray.api.id.ObjectId; -import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; @@ -28,7 +27,6 @@ import org.ray.runtime.generated.Common.Language; import org.ray.runtime.object.ObjectStore; import org.ray.runtime.object.RayObjectImpl; -import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskExecutor; @@ -51,7 +49,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected ObjectStore objectStore; protected TaskSubmitter taskSubmitter; - protected RayletClient rayletClient; protected WorkerContext workerContext; public AbstractRayRuntime(RayConfig rayConfig) { @@ -90,15 +87,6 @@ public void free(List objectIds, boolean localOnly, boolean deleteCrea objectStore.delete(objectIds, localOnly, deleteCreatingTasks); } - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); - if (nodeId == null) { - nodeId = UniqueId.NIL; - } - rayletClient.setResource(resourceName, capacity, nodeId); - } - @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { return objectStore.wait(waitList, numReturns, timeoutMs); @@ -225,10 +213,6 @@ public ObjectStore getObjectStore() { return objectStore; } - public RayletClient getRayletClient() { - return rayletClient; - } - public FunctionManager getFunctionManager() { return functionManager; } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index 7653177a1cf53..c2ac883b320a6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -2,15 +2,19 @@ import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.id.JobId; +import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.LocalModeWorkerContext; import org.ray.runtime.object.LocalModeObjectStore; -import org.ray.runtime.raylet.LocalModeRayletClient; +import org.ray.runtime.task.LocalModeTaskExecutor; import org.ray.runtime.task.LocalModeTaskSubmitter; -import org.ray.runtime.task.TaskExecutor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class RayDevRuntime extends AbstractRayRuntime { + private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class); + public RayDevRuntime(RayConfig rayConfig) { super(rayConfig); } @@ -22,14 +26,13 @@ public void start() { if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } - taskExecutor = new TaskExecutor(this); + taskExecutor = new LocalModeTaskExecutor(this); workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); objectStore = new LocalModeObjectStore(workerContext); taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore, rayConfig.numberExecThreadsForDevRuntime); ((LocalModeObjectStore) objectStore).addObjectPutCallback( objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId)); - rayletClient = new LocalModeRayletClient(); } @Override @@ -37,6 +40,11 @@ public void shutdown() { taskExecutor = null; } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + LOGGER.error("Not implemented under SINGLE_PROCESS mode."); + } + private JobId nextJobId() { return JobId.fromInt(jobCounter.getAndIncrement()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 28a0d0828f0ac..151ce77d79431 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -6,6 +6,7 @@ import java.util.HashMap; import java.util.Map; import org.ray.api.id.JobId; +import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.NativeWorkerContext; import org.ray.runtime.gcs.GcsClient; @@ -13,8 +14,8 @@ import org.ray.runtime.gcs.RedisClient; import org.ray.runtime.generated.Common.WorkerType; import org.ray.runtime.object.NativeObjectStore; -import org.ray.runtime.raylet.NativeRayletClient; import org.ray.runtime.runner.RunManager; +import org.ray.runtime.task.NativeTaskExecutor; import org.ray.runtime.task.NativeTaskSubmitter; import org.ray.runtime.task.TaskExecutor; import org.ray.runtime.util.FileUtil; @@ -103,11 +104,10 @@ public void start() { new GcsClientOptions(rayConfig)); Preconditions.checkState(nativeCoreWorkerPointer != 0); - taskExecutor = new TaskExecutor(this); + taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this); workerContext = new NativeWorkerContext(nativeCoreWorkerPointer); objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer); taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer); - rayletClient = new NativeRayletClient(nativeCoreWorkerPointer); // register registerWorker(); @@ -127,6 +127,15 @@ public void shutdown() { } } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); + if (nodeId == null) { + nodeId = UniqueId.NIL; + } + nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); + } + public void run() { nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor); } @@ -167,4 +176,7 @@ private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer, private static native void nativeSetup(String logDir); private static native void nativeShutdownHook(); + + private static native void nativeSetResource(long conn, String resourceName, double capacity, + byte[] nodeId); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java deleted file mode 100644 index a9976893095a4..0000000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java +++ /dev/null @@ -1,30 +0,0 @@ -package org.ray.runtime.raylet; - -import org.apache.commons.lang3.NotImplementedException; -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Raylet client for local mode. - */ -public class LocalModeRayletClient implements RayletClient { - - private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class); - - @Override - public UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - LOGGER.error("Not implemented under SINGLE_PROCESS mode."); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java deleted file mode 100644 index e423aa8301306..0000000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java +++ /dev/null @@ -1,59 +0,0 @@ -package org.ray.runtime.raylet; - -import org.ray.api.exception.RayException; -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; - -/** - * Raylet client for cluster mode. This is a wrapper class for C++ RayletClient. - */ -public class NativeRayletClient implements RayletClient { - - /** - * The native pointer of core worker. - */ - private long nativeCoreWorkerPointer = 0; - - public NativeRayletClient(long nativeCoreWorkerPointer) { - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; - } - - @Override - public UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall) { - return new UniqueId( - nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(), isDirectCall)); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(), - checkpointId.getBytes()); - } - - - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); - } - - /// Native method declarations. - /// - /// If you change the signature of any native methods, please re-generate - /// the C++ header file and update the C++ implementation accordingly: - /// - /// Suppose that $Dir is your ray root directory. - /// 1) pushd $Dir/java/runtime/target/classes - /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient - /// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h - /// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/ - /// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc - /// 6) popd - - private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId, - boolean isDirectCall); - - private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, - byte[] checkpointId); - - private static native void nativeSetResource(long conn, String resourceName, double capacity, - byte[] nodeId) throws RayException; -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java deleted file mode 100644 index 12e2bc712bd99..0000000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.ray.runtime.raylet; - -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; - -/** - * Client to the Raylet backend. - */ -public interface RayletClient { - - UniqueId prepareCheckpoint(ActorId actorId, boolean isDirectCall); - - void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId); - - void setResource(String resourceName, double capacity, UniqueId nodeId); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 175a99d30a783..d72863ecf8258 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -33,13 +33,17 @@ public static List wrap(Object[] args, boolean crossLanguage, if (arg == null) { data = Serializer.encode(null); } else if (arg instanceof RayObject) { - throwExceptionIfIsDirectActorCall(isDirectActorCall, - "Passing RayObject to a direct call actor is not supported."); + if (isDirectActorCall) { + throw new IllegalArgumentException( + "Passing RayObject to a direct call actor is not supported."); + } id = ((RayObject) arg).getId(); } else if (arg instanceof byte[] && crossLanguage) { // TODO (kfstorm): This could be supported once we supported passing by value with metadata. - throwExceptionIfIsDirectActorCall(isDirectActorCall, - "Passing raw bytes to a direct call actor is not supported."); + if (isDirectActorCall) { + throw new IllegalArgumentException( + "Passing raw bytes to a direct call actor is not supported."); + } // If the argument is a byte array and will be used by a different language, // do not inline this argument. Because the other language doesn't know how // to deserialize it. @@ -62,13 +66,6 @@ public static List wrap(Object[] args, boolean crossLanguage, return ret; } - private static void throwExceptionIfIsDirectActorCall(boolean isDirectActorCall, String message) { - if (isDirectActorCall) { - throw new IllegalArgumentException( - message != null ? message : "Direct actor call only supports by-value arguments."); - } - } - /** * Convert list of NativeRayObject to real function arguments. */ diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java new file mode 100644 index 0000000000000..19016f8221fcc --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java @@ -0,0 +1,22 @@ +package org.ray.runtime.task; + +import org.ray.api.id.ActorId; +import org.ray.runtime.AbstractRayRuntime; + +/** + * Task executor for local mode. + */ +public class LocalModeTaskExecutor extends TaskExecutor { + + public LocalModeTaskExecutor(AbstractRayRuntime runtime) { + super(runtime); + } + + @Override + protected void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { + } + + @Override + protected void maybeLoadCheckpoint(Object actor, ActorId actorId) { + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index b5f22e8d39b24..c8bbacfa1cfc2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -95,12 +95,12 @@ private TaskExecutor getTaskExecutor(TaskSpec task) { if (task.getType() == TaskType.ACTOR_TASK) { taskExecutor = actorTaskExecutors.get(getActorId(task)); } else if (task.getType() == TaskType.ACTOR_CREATION_TASK) { - taskExecutor = new TaskExecutor(runtime); + taskExecutor = new LocalModeTaskExecutor(runtime); actorTaskExecutors.put(getActorId(task), taskExecutor); } else if (idleTaskExecutors.size() > 0) { taskExecutor = idleTaskExecutors.pop(); } else { - taskExecutor = new TaskExecutor(runtime); + taskExecutor = new LocalModeTaskExecutor(runtime); } } currentTaskExecutor.set(taskExecutor); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java new file mode 100644 index 0000000000000..ca190205842c0 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java @@ -0,0 +1,102 @@ +package org.ray.runtime.task; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import org.ray.api.Checkpointable; +import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.Checkpointable.CheckpointContext; +import org.ray.api.id.ActorId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; + +/** + * Task executor for cluster mode. + */ +public class NativeTaskExecutor extends TaskExecutor { + + // TODO(hchen): Use the C++ config. + private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; + + /** + * The native pointer of core worker. + */ + private final long nativeCoreWorkerPointer; + + /** + * Number of tasks executed since last actor checkpoint. + */ + private int numTasksSinceLastCheckpoint = 0; + + /** + * IDs of this actor's previous checkpoints. + */ + private List checkpointIds; + + /** + * Timestamp of the last actor checkpoint. + */ + private long lastCheckpointTimestamp = 0; + + public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) { + super(runtime); + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + protected void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { + if (!(actor instanceof Checkpointable)) { + return; + } + CheckpointContext checkpointContext = new CheckpointContext(actorId, + ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); + Checkpointable checkpointable = (Checkpointable) actor; + if (!checkpointable.shouldCheckpoint(checkpointContext)) { + return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer)); + checkpointIds.add(checkpointId); + if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { + ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); + checkpointIds.remove(0); + } + checkpointable.saveCheckpoint(actorId, checkpointId); + } + + @Override + protected void maybeLoadCheckpoint(Object actor, ActorId actorId) { + if (!(actor instanceof Checkpointable)) { + return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + checkpointIds = new ArrayList<>(); + List availableCheckpoints + = runtime.getGcsClient().getCheckpointsForActor(actorId); + if (availableCheckpoints.isEmpty()) { + return; + } + UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints); + if (checkpointId != null) { + boolean checkpointValid = false; + for (Checkpoint checkpoint : availableCheckpoints) { + if (checkpoint.checkpointId.equals(checkpointId)) { + checkpointValid = true; + break; + } + } + Preconditions.checkArgument(checkpointValid, + "'loadCheckpoint' must return a checkpoint ID that exists in the " + + "'availableCheckpoints' list, or null."); + + nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes()); + } + } + + private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer); + + private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer, + byte[] checkpointId); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index db190cc77b258..2bac8ca9030fe 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -3,16 +3,11 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; -import org.ray.api.Checkpointable; -import org.ray.api.Checkpointable.Checkpoint; -import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; -import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.generated.Common.TaskType; @@ -23,13 +18,10 @@ /** * The task executor, which executes tasks assigned by raylet continuously. */ -public final class TaskExecutor { +public abstract class TaskExecutor { private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); - // TODO(hchen): Use the C++ config. - private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; - protected final AbstractRayRuntime runtime; /** @@ -42,22 +34,7 @@ public final class TaskExecutor { */ private Exception actorCreationException = null; - /** - * Number of tasks executed since last actor checkpoint. - */ - private int numTasksSinceLastCheckpoint = 0; - - /** - * IDs of this actor's previous checkpoints. - */ - private List checkpointIds; - - /** - * Timestamp of the last actor checkpoint. - */ - private long lastCheckpointTimestamp = 0; - - public TaskExecutor(AbstractRayRuntime runtime) { + protected TaskExecutor(AbstractRayRuntime runtime) { this.runtime = runtime; } @@ -134,60 +111,7 @@ private JavaFunctionDescriptor parseFunctionDescriptor(List rayFunctionI rayFunctionInfo.get(2)); } - private void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { - if (!(actor instanceof Checkpointable)) { - return; - } - if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { - // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. - return; - } - CheckpointContext checkpointContext = new CheckpointContext(actorId, - ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); - Checkpointable checkpointable = (Checkpointable) actor; - if (!checkpointable.shouldCheckpoint(checkpointContext)) { - return; - } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId, isDirectCall); - checkpointIds.add(checkpointId); - if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { - ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); - checkpointIds.remove(0); - } - checkpointable.saveCheckpoint(actorId, checkpointId); - } + protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall); - private void maybeLoadCheckpoint(Object actor, ActorId actorId) { - if (!(actor instanceof Checkpointable)) { - return; - } - if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { - // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. - return; - } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - checkpointIds = new ArrayList<>(); - List availableCheckpoints - = runtime.getGcsClient().getCheckpointsForActor(actorId); - if (availableCheckpoints.isEmpty()) { - return; - } - UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints); - if (checkpointId != null) { - boolean checkpointValid = false; - for (Checkpoint checkpoint : availableCheckpoints) { - if (checkpoint.checkpointId.equals(checkpointId)) { - checkpointValid = true; - break; - } - } - Preconditions.checkArgument(checkpointValid, - "'loadCheckpoint' must return a checkpoint ID that exists in the " - + "'availableCheckpoints' list, or null."); - runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId); - } - } + protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index c1c545e8b74e1..abd898f556e0f 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -129,6 +129,25 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook( ray::RayLog::ShutDownRayLog(); } +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, + jdouble capacity, jbyteArray nodeId) { + const auto node_id = JavaByteArrayToId(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + + auto &raylet_client = + reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); + auto status = raylet_client.SetResource(native_resource_name, + static_cast(capacity), node_id); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h index c71fec9829e32..4805646406422 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -48,6 +48,14 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, jclass); +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( + JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc deleted file mode 100644 index ddb907126ce99..0000000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h" -#include -#include "ray/common/id.h" -#include "ray/core_worker/common.h" -#include "ray/core_worker/core_worker.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/raylet/raylet_client.h" - -inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) { - return reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); -} - -#ifdef __cplusplus -extern "C" { -#endif - -using ray::ClientID; - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativePrepareCheckpoint - * Signature: (J[BZ)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, - jboolean isDirectCall) { - const auto actor_id = JavaByteArrayToId(env, actorId); - ActorCheckpointID checkpoint_id; - auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) - .PrepareActorCheckpoint(actor_id, isDirectCall, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, - jbyteArray checkpointId) { - const auto actor_id = JavaByteArrayToId(env, actorId); - const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) - .NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, - jdouble capacity, jbyteArray nodeId) { - const auto node_id = JavaByteArrayToId(env, nodeId); - const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - - auto status = - GetRayletClientFromPointer(nativeCoreWorkerPointer) - .SetResource(native_resource_name, static_cast(capacity), node_id); - env->ReleaseStringUTFChars(resourceName, native_resource_name); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h deleted file mode 100644 index 9f7f67ef9f282..0000000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h +++ /dev/null @@ -1,40 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_raylet_NativeRayletClient */ - -#ifndef _Included_org_ray_runtime_raylet_NativeRayletClient -#define _Included_org_ray_runtime_raylet_NativeRayletClient -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativePrepareCheckpoint - * Signature: (J[BZ)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(JNIEnv *, jclass, - jlong, jbyteArray, - jboolean); - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( - JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc new file mode 100644 index 0000000000000..c9d8e22cc73ee --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc @@ -0,0 +1,55 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/raylet/raylet_client.h" + +#ifdef __cplusplus +extern "C" { +#endif + +using ray::ClientID; + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativePrepareCheckpoint + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); + const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); + const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask(); + RAY_CHECK(task_spec->IsActorTask()); + ActorCheckpointID checkpoint_id; + auto status = core_worker.GetRayletClient().PrepareActorCheckpoint( + actor_id, task_spec->IsDirectCall(), checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + jbyteArray result = env->NewByteArray(checkpoint_id.Size()); + env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), + reinterpret_cast(checkpoint_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) { + auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); + const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); + const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); + auto status = core_worker.GetRayletClient().NotifyActorResumedFromCheckpoint( + actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h new file mode 100644 index 0000000000000..c51bd22e1b3ef --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h @@ -0,0 +1,33 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_task_NativeTaskExecutor */ + +#ifndef _Included_org_ray_runtime_task_NativeTaskExecutor +#define _Included_org_ray_runtime_task_NativeTaskExecutor +#ifdef __cplusplus +extern "C" { +#endif +#undef org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP +#define org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP 20L +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativePrepareCheckpoint + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass, + jlong); + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *, jclass, jlong, jbyteArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 63af6be552948..0dabc27621aaf 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -160,7 +160,7 @@ void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client const ActorID &actor_id, const TaskID &task_id, int num_returns) { - RAY_LOG(DEBUG) << "Push task " << task_id; + RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); auto status = client.PushTask(request, [this, actor_id, task_id, num_returns]( From fa2d4fd589301e7a628c6511e46ffed2bd0f00ce Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Fri, 6 Sep 2019 22:48:25 +0800 Subject: [PATCH 13/24] Unskip testActorReconstruction --- .../main/java/org/ray/api/test/ActorReconstructionTest.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 5f4677c77fd61..17a16333ef82f 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -47,8 +47,6 @@ public int getPid() { @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); - // By design. No lineage cache when direct actor call is enabled. - TestUtils.skipTestIfDirectActorCallEnabled(); ActorCreationOptions options = new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(Counter::new, options); @@ -67,7 +65,7 @@ public void testActorReconstruction() throws InterruptedException, IOException { // Try calling increase on this actor again and check the value is now 4. int value = Ray.call(Counter::increase, actor).get(); - Assert.assertEquals(value, 4); + Assert.assertEquals(value, options.isDirectCall ? 1 : 4); Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get()); From a932190931de9a62de9934d64e1ce2c987a5a107 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Fri, 6 Sep 2019 23:20:22 +0800 Subject: [PATCH 14/24] Cleanup --- .../ray/runtime/context/LocalModeWorkerContext.java | 5 ----- .../org/ray/runtime/context/NativeWorkerContext.java | 7 ------- .../java/org/ray/runtime/context/WorkerContext.java | 5 ----- .../org/ray/runtime/task/LocalModeTaskExecutor.java | 2 +- .../java/org/ray/runtime/task/NativeTaskExecutor.java | 2 +- .../main/java/org/ray/runtime/task/TaskExecutor.java | 7 +++---- .../org_ray_runtime_context_NativeWorkerContext.cc | 11 ----------- .../org_ray_runtime_context_NativeWorkerContext.h | 9 --------- 8 files changed, 5 insertions(+), 43 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java index 06b1d578b8218..1f05c3d5989b7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java @@ -41,11 +41,6 @@ public ActorId getCurrentActorId() { return LocalModeTaskSubmitter.getActorId(taskSpec); } - @Override - public boolean getIsDirectCall() { - return false; - } - @Override public ClassLoader getCurrentClassLoader() { return null; diff --git a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java index 62188a952ca97..b42a7b23411db 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java @@ -38,11 +38,6 @@ public ActorId getCurrentActorId() { return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer)); } - @Override - public boolean getIsDirectCall() { - return nativeGetIsDirectCall(nativeCoreWorkerPointer); - } - @Override public ClassLoader getCurrentClassLoader() { return currentClassLoader; @@ -74,6 +69,4 @@ public TaskId getCurrentTaskId() { private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer); private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer); - - private static native boolean nativeGetIsDirectCall(long nativeCoreWorkerPointer); } diff --git a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java index 23894e0548db4..4a526c85ecbfb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java @@ -26,11 +26,6 @@ public interface WorkerContext { */ ActorId getCurrentActorId(); - /** - * Whether the current task is a direct call task. - */ - boolean getIsDirectCall(); - /** * The class loader that is associated with the current job. It's used for locating classes when * dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}. diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java index 19016f8221fcc..24e6f15b98fbf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java @@ -13,7 +13,7 @@ public LocalModeTaskExecutor(AbstractRayRuntime runtime) { } @Override - protected void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { + protected void maybeSaveCheckpoint(Object actor, ActorId actorId) { } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java index ca190205842c0..36e7259a4b9d1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java @@ -44,7 +44,7 @@ public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runti } @Override - protected void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall) { + protected void maybeSaveCheckpoint(Object actor, ActorId actorId) { if (!(actor instanceof Checkpointable)) { return; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 2bac8ca9030fe..ad54a9a352f50 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -76,8 +76,7 @@ protected List execute(List rayFunctionInfo, if (taskType != TaskType.ACTOR_CREATION_TASK) { if (taskType == TaskType.ACTOR_TASK) { // TODO (kfstorm): handle checkpoint in core worker. - maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId(), - runtime.getWorkerContext().getIsDirectCall()); + maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId()); } if (rayFunction.hasReturn()) { returnObjects.add(runtime.getObjectStore().serialize(result)); @@ -91,7 +90,7 @@ protected List execute(List rayFunctionInfo, } catch (Exception e) { LOGGER.error("Error executing task " + taskId, e); if (taskType != TaskType.ACTOR_CREATION_TASK) { - if(rayFunction.hasReturn()) { + if (rayFunction.hasReturn()) { returnObjects.add(runtime.getObjectStore() .serialize(new RayTaskException("Error executing task " + taskId, e))); } @@ -111,7 +110,7 @@ private JavaFunctionDescriptor parseFunctionDescriptor(List rayFunctionI rayFunctionInfo.get(2)); } - protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId, boolean isDirectCall); + protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId); protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc index cc64478e75d15..b7e7910446cd0 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc @@ -78,17 +78,6 @@ Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId( return IdToJavaByteBuffer(env, actor_id); } -/* - * Class: org_ray_runtime_context_NativeWorkerContext - * Method: nativeGetIsDirectCall - * Signature: (J)Z - */ -JNIEXPORT jboolean JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetIsDirectCall( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { - return GetWorkerContextFromPointer(nativeCoreWorkerPointer).IsDirectCall(); -} - #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h index b18060f0d772a..fe3725484a0ea 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h @@ -52,15 +52,6 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass, jlong); -/* - * Class: org_ray_runtime_context_NativeWorkerContext - * Method: nativeGetIsDirectCall - * Signature: (J)Z - */ -JNIEXPORT jboolean JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetIsDirectCall(JNIEnv *, jclass, - jlong); - #ifdef __cplusplus } #endif From f5570c7654aa75389e12a5951ecd5e999a2c0975 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 00:12:04 +0800 Subject: [PATCH 15/24] Add is_direct_call to actor creation task and ActorTableData --- .../org/ray/runtime/AbstractRayRuntime.java | 10 +- .../org/ray/runtime/actor/NativeRayActor.java | 6 +- .../ray/runtime/task/ArgumentsBuilder.java | 8 +- python/ray/_raylet.pyx | 2 +- python/ray/includes/libraylet.pxd | 1 - python/ray/includes/task.pxd | 6 +- python/ray/includes/task.pxi | 2 +- src/ray/common/task/task_spec.cc | 10 +- src/ray/common/task/task_util.h | 17 +- src/ray/core_worker/context.cc | 6 +- src/ray/core_worker/context.h | 5 +- .../org_ray_runtime_actor_NativeRayActor.cc | 4 +- .../org_ray_runtime_actor_NativeRayActor.h | 4 +- ...org_ray_runtime_task_NativeTaskExecutor.cc | 2 +- src/ray/core_worker/task_execution.cc | 7 +- src/ray/core_worker/task_interface.cc | 6 +- src/ray/core_worker/test/core_worker_test.cc | 3 +- .../transport/direct_actor_transport.cc | 16 +- .../transport/direct_actor_transport.h | 5 +- .../core_worker/transport/raylet_transport.cc | 11 +- .../core_worker/transport/raylet_transport.h | 5 +- src/ray/protobuf/common.proto | 4 +- src/ray/protobuf/gcs.proto | 2 + src/ray/raylet/format/node_manager.fbs | 2 - ...org_ray_runtime_raylet_RayletClientImpl.cc | 292 ------------------ src/ray/raylet/node_manager.cc | 4 +- src/ray/raylet/raylet_client.cc | 5 +- src/ray/raylet/raylet_client.h | 3 +- src/ray/raylet/worker.cc | 1 + 29 files changed, 87 insertions(+), 362 deletions(-) delete mode 100644 src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index ed1c9bcc6ee9c..71f24131cad92 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -160,7 +160,7 @@ public RayPyActor createPyActor(String moduleName, String className, Object[] ar private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args, int numReturns, CallOptions options) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectActorCall*/false); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectCall*/false); List returnIds = taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -174,7 +174,7 @@ private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, private RayObject callActorFunction(RayActor rayActor, FunctionDescriptor functionDescriptor, Object[] args, int numReturns) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, isDirectActorCall(rayActor)); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, isDirectCall(rayActor)); List returnIds = taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -188,7 +188,7 @@ private RayObject callActorFunction(RayActor rayActor, private RayActor createActorImpl(FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) { List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectActorCall*/false); + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA, /*isDirectCall*/false); if (functionDescriptor.getLanguage() != Language.JAVA && options != null) { Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions)); } @@ -198,9 +198,9 @@ private RayActor createActorImpl(FunctionDescriptor functionDescriptor, return actor; } - private boolean isDirectActorCall(RayActor rayActor) { + private boolean isDirectCall(RayActor rayActor) { if (rayActor instanceof NativeRayActor) { - return ((NativeRayActor) rayActor).isDirectCall(); + return ((NativeRayActor) rayActor).isDirectCallActor(); } return false; } diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java index cbe13c092a411..8dd7ac8c34f63 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -51,8 +51,8 @@ public Language getLanguage() { return Language.forNumber(nativeGetLanguage(nativeActorHandle)); } - public boolean isDirectCall() { - return nativeIsDirectCall(nativeActorHandle); + public boolean isDirectCallActor() { + return nativeIsDirectCallActor(nativeActorHandle); } @Override @@ -94,7 +94,7 @@ protected void finalize() { private static native int nativeGetLanguage(long nativeActorHandle); - private static native boolean nativeIsDirectCall(long nativeActorHandle); + private static native boolean nativeIsDirectCallActor(long nativeActorHandle); private static native List nativeGetActorCreationTaskFunctionDescriptor( long nativeActorHandle); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index d72863ecf8258..56f87c353a1d8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -25,7 +25,7 @@ public class ArgumentsBuilder { * Convert real function arguments to task spec arguments. */ public static List wrap(Object[] args, boolean crossLanguage, - boolean isDirectActorCall) { + boolean isDirectCall) { List ret = new ArrayList<>(); for (Object arg : args) { ObjectId id = null; @@ -33,14 +33,14 @@ public static List wrap(Object[] args, boolean crossLanguage, if (arg == null) { data = Serializer.encode(null); } else if (arg instanceof RayObject) { - if (isDirectActorCall) { + if (isDirectCall) { throw new IllegalArgumentException( "Passing RayObject to a direct call actor is not supported."); } id = ((RayObject) arg).getId(); } else if (arg instanceof byte[] && crossLanguage) { // TODO (kfstorm): This could be supported once we supported passing by value with metadata. - if (isDirectActorCall) { + if (isDirectCall) { throw new IllegalArgumentException( "Passing raw bytes to a direct call actor is not supported."); } @@ -50,7 +50,7 @@ public static List wrap(Object[] args, boolean crossLanguage, id = Ray.put(arg).getId(); } else { byte[] serialized = Serializer.encode(arg); - if (!isDirectActorCall && serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { + if (!isDirectCall && serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { id = ((AbstractRayRuntime) Ray.internal()).getObjectStore() .put(new NativeRayObject(serialized, null)); } else { diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 47b90942b2dc0..a1e9387fe86fe 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -357,7 +357,7 @@ cdef class RayletClient: # the GIL so other Python threads can run. with nogil: check_status(self.client.get().PrepareActorCheckpoint( - c_actor_id, False, checkpoint_id)) + c_actor_id, checkpoint_id)) return ActorCheckpointID(checkpoint_id.Binary()) def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index e5dcf6efe1689..45746da2fdf43 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -68,7 +68,6 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids, c_bool local_only, c_bool delete_creating_tasks) CRayStatus PrepareActorCheckpoint(const CActorID &actor_id, - c_bool is_direct_call, CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 40baabe9f0b8c..fa828d1e5f307 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -90,15 +90,15 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil: TaskSpecBuilder &SetActorCreationTaskSpec( const CActorID &actor_id, uint64_t max_reconstructions, - const c_vector[c_string] &dynamic_worker_options) + const c_vector[c_string] &dynamic_worker_options, + c_bool is_direct_call) TaskSpecBuilder &SetActorTaskSpec( const CActorID &actor_id, const CActorHandleID &actor_handle_id, const CObjectID &actor_creation_dummy_object_id, const CObjectID &previous_actor_task_dummy_object_id, uint64_t actor_counter, - const c_vector[CActorHandleID] &new_handle_ids, - c_bool is_direct_call); + const c_vector[CActorHandleID] &new_handle_ids); RpcTaskSpec GetMessage() diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index d052a791026ee..fe32b6ce70405 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -79,6 +79,7 @@ cdef class TaskSpec: actor_creation_id.native(), max_actor_reconstructions, [], + False, ) elif not actor_id.is_nil(): # Actor task. @@ -92,7 +93,6 @@ cdef class TaskSpec: previous_actor_task_dummy_object_id.native(), actor_counter, c_new_actor_handles, - False, ) else: # Normal task. diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index d6172810c82f7..e547133c6fc9e 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -147,8 +147,8 @@ std::vector TaskSpecification::NewActorHandles() const { } bool TaskSpecification::IsDirectCall() const { - RAY_CHECK(IsActorTask()); - return message_->actor_task_spec().is_direct_call(); + RAY_CHECK(IsActorCreationTask()); + return message_->actor_creation_task_spec().is_direct_call(); } std::string TaskSpecification::DebugString() const { @@ -174,13 +174,13 @@ std::string TaskSpecification::DebugString() const { if (IsActorCreationTask()) { // Print actor creation task spec. stream << ", actor_creation_task_spec={actor_id=" << ActorCreationId() - << ", max_reconstructions=" << MaxActorReconstructions() << "}"; + << ", max_reconstructions=" << MaxActorReconstructions() + << ", is_direct_call=" << IsDirectCall() << "}"; } else if (IsActorTask()) { // Print actor task spec. stream << ", actor_task_spec={actor_id=" << ActorId() << ", actor_handle_id=" << ActorHandleId() - << ", actor_counter=" << ActorCounter() - << ", is_direct_call=" << IsDirectCall() << "}"; + << ", actor_counter=" << ActorCounter() << "}"; } return stream.str(); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index a343e8da97aec..d26039c6f3be7 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -78,7 +78,8 @@ class TaskSpecBuilder { /// \return Reference to the builder object itself. TaskSpecBuilder &SetActorCreationTaskSpec( const ActorID &actor_id, uint64_t max_reconstructions = 0, - const std::vector &dynamic_worker_options = {}) { + const std::vector &dynamic_worker_options = {}, + bool is_direct_call = false) { message_->set_type(TaskType::ACTOR_CREATION_TASK); auto actor_creation_spec = message_->mutable_actor_creation_task_spec(); actor_creation_spec->set_actor_id(actor_id.Binary()); @@ -86,6 +87,7 @@ class TaskSpecBuilder { for (const auto &option : dynamic_worker_options) { actor_creation_spec->add_dynamic_worker_options(option); } + actor_creation_spec->set_is_direct_call(is_direct_call); return *this; } @@ -93,13 +95,11 @@ class TaskSpecBuilder { /// See `common.proto` for meaning of the arguments. /// /// \return Reference to the builder object itself. - TaskSpecBuilder &SetActorTaskSpec(const ActorID &actor_id, - const ActorHandleID &actor_handle_id, - const ObjectID &actor_creation_dummy_object_id, - const ObjectID &previous_actor_task_dummy_object_id, - uint64_t actor_counter, - const std::vector &new_handle_ids = {}, - bool is_direct_call = false) { + TaskSpecBuilder &SetActorTaskSpec( + const ActorID &actor_id, const ActorHandleID &actor_handle_id, + const ObjectID &actor_creation_dummy_object_id, + const ObjectID &previous_actor_task_dummy_object_id, uint64_t actor_counter, + const std::vector &new_handle_ids = {}) { message_->set_type(TaskType::ACTOR_TASK); auto actor_spec = message_->mutable_actor_task_spec(); actor_spec->set_actor_id(actor_id.Binary()); @@ -112,7 +112,6 @@ class TaskSpecBuilder { for (const auto &id : new_handle_ids) { actor_spec->add_new_actor_handles(id.Binary()); } - actor_spec->set_is_direct_call(is_direct_call); return *this; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 9f214d2b918ed..9a8941624ce28 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -80,6 +80,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); + is_direct_call_actor_ = task_spec.IsDirectCall(); } if (task_spec.IsActorTask()) { RAY_CHECK(current_actor_id_ == task_spec.ActorId()); @@ -91,10 +92,7 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } -bool WorkerContext::IsDirectCall() const { - std::shared_ptr task = GetThreadContext().GetCurrentTask(); - return task && task->IsActorTask() && task->IsDirectCall(); -} +bool WorkerContext::IsDirectCallActor() const { return is_direct_call_actor_; } WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 40bdf1bade6bf..7082c43020ffe 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -26,7 +26,7 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; - bool IsDirectCall() const; + bool IsDirectCallActor() const; int GetNextTaskIndex(); @@ -45,6 +45,9 @@ class WorkerContext { /// ID of current actor. ActorID current_actor_id_; + /// Whether direct actor call is used. + bool is_direct_call_actor_; + private: static WorkerThreadContext &GetThreadContext(); diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index f632f280f75b4..1a91e0fb036b2 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -59,10 +59,10 @@ JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLangua /* * Class: org_ray_runtime_actor_NativeRayActor - * Method: nativeIsDirectCall + * Method: nativeIsDirectCallActor * Signature: (J)Z */ -JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCall( +JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor( JNIEnv *env, jclass o, jlong nativeActorHandle) { return GetActorHandle(nativeActorHandle).IsDirectCallActor(); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h index 92363773c8d70..245064fcff447 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -42,11 +42,11 @@ Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jl /* * Class: org_ray_runtime_actor_NativeRayActor - * Method: nativeIsDirectCall + * Method: nativeIsDirectCallActor * Signature: (J)Z */ JNIEXPORT jboolean JNICALL -Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCall(JNIEnv *, jclass, jlong); +Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor(JNIEnv *, jclass, jlong); /* * Class: org_ray_runtime_actor_NativeRayActor diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc index c9d8e22cc73ee..8658c2f8a1b8d 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc @@ -26,7 +26,7 @@ Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint( RAY_CHECK(task_spec->IsActorTask()); ActorCheckpointID checkpoint_id; auto status = core_worker.GetRayletClient().PrepareActorCheckpoint( - actor_id, task_spec->IsDirectCall(), checkpoint_id); + actor_id, checkpoint_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); jbyteArray result = env->NewByteArray(checkpoint_id.Size()); env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index f397ab314df3b..2040d4de1755c 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -22,12 +22,13 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( task_receivers_.emplace( TaskTransportType::RAYLET, std::unique_ptr(new CoreWorkerRayletTaskReceiver( - raylet_client, object_interface_, *main_service_, worker_server_, func))); + worker_context_, raylet_client, object_interface_, *main_service_, + worker_server_, func))); task_receivers_.emplace( TaskTransportType::DIRECT_ACTOR, std::unique_ptr( - new CoreWorkerDirectActorTaskReceiver(object_interface_, *main_service_, - worker_server_, func))); + new CoreWorkerDirectActorTaskReceiver(worker_context_, object_interface_, + *main_service_, worker_server_, func))); // Start RPC server after all the task receivers are properly initialized. worker_server_.Run(); diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index be588178164ce..001481180260c 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -173,7 +173,8 @@ Status CoreWorkerTaskInterface::CreateActor( actor_creation_options.resources, actor_creation_options.resources, TaskTransportType::RAYLET, &return_ids); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions, - actor_creation_options.dynamic_worker_options); + actor_creation_options.dynamic_worker_options, + actor_creation_options.is_direct_call); *actor_handle = std::unique_ptr(new ActorHandle( actor_id, ActorHandleID::Nil(), function.language, @@ -217,8 +218,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), - actor_handle.IncreaseTaskCounter(), actor_handle.NewActorHandles(), - actor_handle.IsDirectCallActor()); + actor_handle.IncreaseTaskCounter(), actor_handle.NewActorHandles()); // Manipulate actor handle state. auto actor_cursor = (*return_ids).back(); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index c33d9d6772fbf..b973e7c774a0b 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -674,8 +674,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { builder.SetActorTaskSpec( actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, - /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), 0, {}, - actor_handle.IsDirectCallActor()); + /*previous_actor_task_dummy_object_id=*/actor_handle.ActorCursor(), 0, {}); const auto &task_spec = builder.Build(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 0dabc27621aaf..dd42432374369 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -21,8 +21,7 @@ CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( : io_service_(io_service), gcs_client_(gcs_client), client_call_manager_(io_service), - store_provider_(std::move(store_provider)) { -} + store_provider_(std::move(store_provider)) {} Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( const TaskSpecification &task_spec) { @@ -227,9 +226,11 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) { } CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver( - CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, - rpc::GrpcServer &server, const TaskHandler &task_handler) - : object_interface_(object_interface), + WorkerContext &worker_context, CoreWorkerObjectInterface &object_interface, + boost::asio::io_service &io_service, rpc::GrpcServer &server, + const TaskHandler &task_handler) + : worker_context_(worker_context), + object_interface_(object_interface), task_service_(io_service, *this), task_handler_(task_handler) { server.RegisterService(task_service_); @@ -246,6 +247,11 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( nullptr); return; } + if (task_spec.IsActorTask() && !worker_context_.IsDirectCallActor()) { + send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."), + nullptr, nullptr); + return; + } auto num_returns = task_spec.NumReturns(); RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask()); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index c280130c0e66e..effa7b01df1a9 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -119,7 +119,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver, public rpc::DirectActorHandler { public: - CoreWorkerDirectActorTaskReceiver(CoreWorkerObjectInterface &object_interface, + CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context, + CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler); @@ -135,6 +136,8 @@ class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver, rpc::SendReplyCallback send_reply_callback) override; private: + // Worker context. + WorkerContext &worker_context_; // Object interface. CoreWorkerObjectInterface &object_interface_; /// The rpc service for `DirectActorService`. diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 004434f9d0a17..f99ba90e66dbe 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -14,10 +14,11 @@ Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpecification &task) } CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver( - std::unique_ptr &raylet_client, + WorkerContext &worker_context, std::unique_ptr &raylet_client, CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler) - : raylet_client_(raylet_client), + : worker_context_(worker_context), + raylet_client_(raylet_client), object_interface_(object_interface), task_service_(io_service, *this), task_handler_(task_handler) { @@ -30,6 +31,12 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( const Task task(request.task()); const auto &task_spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); + if (task_spec.IsActorTask() && worker_context_.IsDirectCallActor()) { + send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, + nullptr); + return; + } + std::vector> results; auto status = task_handler_(task_spec, &results); diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 0ba8feb5ed644..39a529cde5956 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -32,7 +32,8 @@ class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter { class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver, public rpc::WorkerTaskHandler { public: - CoreWorkerRayletTaskReceiver(std::unique_ptr &raylet_client, + CoreWorkerRayletTaskReceiver(WorkerContext &worker_context, + std::unique_ptr &raylet_client, CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler); @@ -49,6 +50,8 @@ class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver, rpc::SendReplyCallback send_reply_callback) override; private: + // Worker context. + WorkerContext &worker_context_; /// Raylet client. std::unique_ptr &raylet_client_; // Object interface. diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index df95593f7c6e2..27f47b605ebf4 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -88,6 +88,8 @@ message ActorCreationTaskSpec { // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the // worker command. repeated string dynamic_worker_options = 4; + // Whether direct actor call is used. + bool is_direct_call = 5; } // Task spec of an actor task. @@ -106,8 +108,6 @@ message ActorTaskSpec { repeated bytes new_actor_handles = 6; // The dummy object ID of the previous actor task. bytes previous_actor_task_dummy_object_id = 7; - // Whether direct actor call is used. - bool is_direct_call = 8; } // The task execution specification encapsulates all mutable information about diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 218e1acd577a3..8041cc40fe4fc 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -109,6 +109,8 @@ message ActorTableData { string ip_address = 9; // The port that the actor is listening on. int32 port = 10; + // Whether direct actor call is used. + bool is_direct_call = 11; } message ErrorTableData { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 60e20a5aea6e9..705a9fdba9dd3 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -219,8 +219,6 @@ table FreeObjectsRequest { table PrepareActorCheckpointRequest { // ID of the actor. actor_id: string; - // Whether direct actor call is used. - is_direct_call: bool; } table PrepareActorCheckpointReply { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc deleted file mode 100644 index a9ef670b930da..0000000000000 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ /dev/null @@ -1,292 +0,0 @@ -#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h" - -#include - -#include "ray/common/id.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/raylet/raylet_client.h" -#include "ray/util/logging.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeInit - * Signature: (Ljava/lang/String;[BZ[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( - JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, - jbyteArray jobId) { - const auto worker_id = JavaByteArrayToId(env, workerId); - const auto job_id = JavaByteArrayToId(env, jobId); - const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new std::unique_ptr( - new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA)); - env->ReleaseStringUTFChars(sockName, nativeString); - return reinterpret_cast(raylet_client); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSubmitTask - * Signature: (J[BLjava/nio/ByteBuffer;II)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) { - auto &raylet_client = *reinterpret_cast *>(client); - - jbyte *data = env->GetByteArrayElements(taskSpec, NULL); - jsize size = env->GetArrayLength(taskSpec); - ray::rpc::TaskSpec task_spec_message; - task_spec_message.ParseFromArray(data, size); - env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); - - ray::TaskSpecification task_spec(task_spec_message); - auto status = raylet_client->SubmitTask(task_spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGetTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( - JNIEnv *env, jclass, jlong client) { - auto &raylet_client = *reinterpret_cast *>(client); - - std::unique_ptr spec; - auto status = raylet_client->GetTask(&spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Serialize the task spec and copy to Java byte array. - auto task_data = spec->Serialize(); - - jbyteArray result = env->NewByteArray(task_data.size()); - if (result == nullptr) { - return nullptr; /* out of memory error thrown */ - } - - env->SetByteArrayRegion(result, 0, task_data.size(), - reinterpret_cast(task_data.data())); - - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( - JNIEnv *env, jclass, jlong client) { - auto raylet_client = reinterpret_cast *>(client); - auto status = (*raylet_client)->Disconnect(); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); - delete raylet_client; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeWaitObject - * Signature: (J[[BIIZ[B)[Z - */ -JNIEXPORT jbooleanArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns, - jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - - auto &raylet_client = *reinterpret_cast *>(client); - - // Invoke wait. - WaitResultPair result; - auto status = - raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), current_task_id, &result); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Convert result to java object. - jboolean put_value = true; - jbooleanArray resultArray = env->NewBooleanArray(object_ids.size()); - for (uint i = 0; i < result.first.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.first[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - - put_value = false; - for (uint i = 0; i < result.second.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.second[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - return resultArray; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorCreationTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - - const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter); - const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id); - jbyteArray result = env->NewByteArray(actor_creation_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(), - reinterpret_cast(actor_creation_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorTaskId - * Signature: ([B[BI[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter, jbyteArray actorId) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const auto actor_id = JavaByteArrayToId(env, actorId); - const TaskID actor_task_id = - ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id); - - jbyteArray result = env->NewByteArray(actor_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_task_id.Size(), - reinterpret_cast(actor_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateNormalTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const TaskID task_id = - ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter); - - jbyteArray result = env->NewByteArray(task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, task_id.Size(), - reinterpret_cast(task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFreePlasmaObjects - * Signature: (J[[BZZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly, - jboolean deleteCreatingTasks) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - auto &raylet_client = *reinterpret_cast *>(client); - auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, - jlong client, - jbyteArray actorId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( - JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, - jbyteArray nodeId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto node_id = JavaByteArrayToId(env, nodeId); - const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - - auto status = raylet_client->SetResource(native_resource_name, - static_cast(capacity), node_id); - env->ReleaseStringUTFChars(resourceName, native_resource_name); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index dd432ab992da3..83922e0e0441d 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1219,7 +1219,6 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( auto message = flatbuffers::GetRoot(message_data); ActorID actor_id = from_flatbuf(*message->actor_id()); - bool is_direct_call = message->is_direct_call(); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); @@ -1229,7 +1228,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom(); std::shared_ptr checkpoint_data; - if (is_direct_call) { + if (actor_entry->second.GetTableData().is_direct_call()) { checkpoint_data = actor_entry->second.GenerateCheckpointData(actor_entry->first, nullptr); } else { @@ -1922,6 +1921,7 @@ std::shared_ptr NodeManager::CreateActorTableDataFromCreationTas // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. actor_info_ptr->set_remaining_reconstructions(task_spec.MaxActorReconstructions()); + actor_info_ptr->set_is_direct_call(task_spec.IsDirectCall()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index be1db5297023b..1c8871bf0bd08 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -354,11 +354,10 @@ ray::Status RayletClient::FreeObjects(const std::vector &object_i } ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, - bool is_direct_call, ActorCheckpointID &checkpoint_id) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePrepareActorCheckpointRequest( - fbb, to_flatbuf(fbb, actor_id), is_direct_call); + auto message = + ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); fbb.Finish(message); std::unique_ptr reply; diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 793e2ea5a57ab..235ba9cfb890b 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -155,10 +155,9 @@ class RayletClient { /// Request raylet backend to prepare a checkpoint for an actor. /// /// \param actor_id ID of the actor. - /// \param is_direct_call Whether direct actor call is used. /// \param checkpoint_id ID of the new checkpoint (output parameter). /// \return ray::Status. - ray::Status PrepareActorCheckpoint(const ActorID &actor_id, bool is_direct_call, + ray::Status PrepareActorCheckpoint(const ActorID &actor_id, ActorCheckpointID &checkpoint_id); /// Notify raylet backend that an actor was resumed from a checkpoint. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 52814bb708293..eb48b92d53fad 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -130,6 +130,7 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, auto status = rpc_client_->AssignTask( request, [](Status status, const rpc::AssignTaskReply &reply) { + RAY_CHECK_OK_PREPEND(status, "Worker failed to finish executing task."); // Worker has finished this task. There's nothing to do here // and assigning new task will be done when raylet receives // `TaskDone` message. From 130e66721fed3063b36e58b79acd873b70a82f1b Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 00:31:18 +0800 Subject: [PATCH 16/24] Update comment --- src/ray/core_worker/context.h | 2 +- src/ray/raylet/actor_registration.cc | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 7082c43020ffe..8317f10119e61 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -45,7 +45,7 @@ class WorkerContext { /// ID of current actor. ActorID current_actor_id_; - /// Whether direct actor call is used. + /// Whether current actor accepts direct calls. bool is_direct_call_actor_; private: diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index d2d956159163d..575ce095d6385 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -104,8 +104,7 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( if (task) { const auto actor_handle_id = task->GetTaskSpecification().ActorHandleId(); const auto dummy_object = task->GetTaskSpecification().ActorDummyObject(); - // Extend its frontier to include - // the most recent task. + // Extend its frontier to include the most recent task. // Note(hchen): this is needed because this method is called before // `FinishAssignedTask`, which will be called when the worker tries to fetch // the next task. From 358645630f85be8fc3ce1d2bb50f53419f9caa8a Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 03:05:53 +0800 Subject: [PATCH 17/24] skip some tests not relevant to direct call --- java/BUILD.bazel | 5 +-- java/test.sh | 6 +-- .../src/main/java/org/ray/api/RunTestNG.java | 39 +++++++++++++++++++ .../ray/api/test/ActorReconstructionTest.java | 3 +- .../main/java/org/ray/api/test/ActorTest.java | 15 +++---- .../ray/api/test/BaseMultiLanguageTest.java | 4 +- .../main/java/org/ray/api/test/BaseTest.java | 4 +- .../api/test/CrossLanguageInvocationTest.java | 2 +- .../java/org/ray/api/test/FailureTest.java | 12 +++--- .../org/ray/api/test/MultiThreadingTest.java | 9 ++--- .../java/org/ray/api/test/StressTest.java | 3 +- java/testng.xml | 10 ----- 12 files changed, 64 insertions(+), 48 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/RunTestNG.java delete mode 100644 java/testng.xml diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 37ef5b93b6eed..99b51b1b9272c 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -2,7 +2,6 @@ load("//bazel:ray.bzl", "define_java_module") load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ - "testng.xml", "checkstyle.xml", "checkstyle-suppressions.xml", "streaming/testng.xml", @@ -131,9 +130,7 @@ define_java_module( java_binary( name = "all_tests", - main_class = "org.testng.TestNG", - data = ["testng.xml"], - args = ["java/testng.xml"], + main_class = "org.ray.api.RunTestNG", runtime_deps = [ ":org_ray_ray_test", ":org_ray_ray_runtime_test", diff --git a/java/test.sh b/java/test.sh index bc74a4072ab75..7dae453af4f82 100755 --- a/java/test.sh +++ b/java/test.sh @@ -25,14 +25,14 @@ echo "Running tests under cluster mode." # TODO(hchen): Ideally, we should use the following bazel command to run Java tests. However, if there're skipped tests, # TestNG will exit with code 2. And bazel treats it as test failure. # bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$? -ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG echo "Running tests under cluster mode with direct actor call turned on." -ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? -run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG echo "Running streaming tests." run_testng java -cp $ROOT_DIR/../bazel-bin/java/streaming_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/streaming/testng.xml diff --git a/java/test/src/main/java/org/ray/api/RunTestNG.java b/java/test/src/main/java/org/ray/api/RunTestNG.java new file mode 100644 index 0000000000000..b1d85d262cbab --- /dev/null +++ b/java/test/src/main/java/org/ray/api/RunTestNG.java @@ -0,0 +1,39 @@ +package org.ray.api; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import org.ray.api.options.ActorCreationOptions; +import org.testng.TestNG; +import org.testng.xml.XmlGroups; +import org.testng.xml.XmlPackage; +import org.testng.xml.XmlRun; +import org.testng.xml.XmlSuite; +import org.testng.xml.XmlTest; + +public class RunTestNG { + + public static void main(String args[]) { + TestNG testng = new TestNG(); + XmlSuite suite = new XmlSuite(); + suite.setName("RAY suite"); + suite.setVerbose(2); + XmlTest test = new XmlTest(suite); + test.setName("RAY test"); + List packages = new ArrayList<>(); + packages.add(new XmlPackage("org.ray.api.test.*")); + packages.add(new XmlPackage("org.ray.runtime.*")); + test.setPackages(packages); + if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { + XmlGroups groups = new XmlGroups(); + XmlRun run = new XmlRun(); + run.onInclude("directCall"); + groups.setRun(run); + test.setGroups(groups); + } + testng.setXmlSuites(ImmutableList.of(suite)); + testng.setOutputDirectory("/tmp/ray_java_test_output"); + testng.run(); + System.exit(testng.getStatus()); + } +} diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 17a16333ef82f..4028030dbaa19 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -17,6 +17,7 @@ import org.testng.Assert; import org.testng.annotations.Test; +@Test(groups = {"directCall"}) public class ActorReconstructionTest extends BaseTest { @RayRemote() @@ -44,7 +45,6 @@ public int getPid() { } } - @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); ActorCreationOptions options = @@ -125,7 +125,6 @@ public void checkpointExpired(ActorId actorId, UniqueId checkpointId) { } } - @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); ActorCreationOptions options = diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index cfe5382530f45..67a0232424163 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -18,6 +18,7 @@ import org.testng.Assert; import org.testng.annotations.Test; +@Test(groups = {"directCall"}) public class ActorTest extends BaseTest { @RayRemote @@ -48,7 +49,6 @@ public int accessLargeObject(LargeObject largeObject) { } } - @Test public void testCreateAndCallActor() { // Test creating an actor from a constructor RayActor actor = Ray.createActor(Counter::new, 1); @@ -59,7 +59,6 @@ public void testCreateAndCallActor() { Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get()); } - @Test public void testCallActorWithLargeObject() { RayActor actor = Ray.createActor(Counter::new, 1); LargeObject largeObject = new LargeObject(); @@ -68,11 +67,10 @@ public void testCallActorWithLargeObject() { } @RayRemote - public static Counter factory(int initValue) { + static Counter factory(int initValue) { return new Counter(initValue); } - @Test public void testCreateActorFromFactory() { // Test creating an actor from a factory method RayActor actor = Ray.createActor(ActorTest::factory, 1); @@ -82,24 +80,23 @@ public void testCreateActorFromFactory() { } @RayRemote - public static int testActorAsFirstParameter(RayActor actor, int delta) { + static int testActorAsFirstParameter(RayActor actor, int delta) { RayObject res = Ray.call(Counter::increaseAndGet, actor, delta); return res.get(); } @RayRemote - public static int testActorAsSecondParameter(int delta, RayActor actor) { + static int testActorAsSecondParameter(int delta, RayActor actor) { RayObject res = Ray.call(Counter::increaseAndGet, actor, delta); return res.get(); } @RayRemote - public static int testActorAsFieldOfParameter(List> actor, int delta) { + static int testActorAsFieldOfParameter(List> actor, int delta) { RayObject res = Ray.call(Counter::increaseAndGet, actor.get(0), delta); return res.get(); } - @Test public void testPassActorAsParameter() { RayActor actor = Ray.createActor(Counter::new, 0); Assert.assertEquals(Integer.valueOf(1), @@ -111,7 +108,6 @@ public void testPassActorAsParameter() { .get()); } - @Test public void testForkingActorHandle() { TestUtils.skipTestUnderSingleProcess(); RayActor counter = Ray.createActor(Counter::new, 100); @@ -120,7 +116,6 @@ public void testForkingActorHandle() { Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get()); } - @Test public void testUnreconstructableActorObject() throws InterruptedException { TestUtils.skipTestUnderSingleProcess(); // The UnreconstructableException is created by raylet. diff --git a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java index 939372b96ad7b..4499d141efab8 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java @@ -44,7 +44,7 @@ private boolean executeCommand(List command, int waitTimeoutSeconds, } } - @BeforeClass + @BeforeClass(alwaysRun = true) public void setUp() { if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) { LOGGER.info("Skip Multi-language tests because environment variable " @@ -92,7 +92,7 @@ protected Map getRayStartEnv() { return ImmutableMap.of(); } - @AfterClass + @AfterClass(alwaysRun = true) public void tearDown() { // Disconnect to the cluster. Ray.shutdown(); diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java index 4c3973064e32b..fa1d078de460b 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java @@ -16,7 +16,7 @@ public class BaseTest { private List filesToDelete; - @BeforeMethod + @BeforeMethod(alwaysRun = true) public void setUpBase(Method method) { LOGGER.info("===== Running test: " + method.getDeclaringClass().getName() + "." + method.getName()); @@ -34,7 +34,7 @@ public void setUpBase(Method method) { filesToDelete.forEach(File::deleteOnExit); } - @AfterMethod + @AfterMethod(alwaysRun = true) public void tearDownBase() { Ray.shutdown(); diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index bfc996349ce00..6ad2871eae0bd 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -46,7 +46,7 @@ public void testCallingPythonFunction() { Assert.assertEquals(res.get(), "Response from Python: hello".getBytes()); } - @Test + @Test(groups = {"directCall"}) public void testCallingPythonActor() { // Direct actor call only allows passing arguments as values. // However, bytes arguments are passed from Java to Python as references. diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 99a33cb6656d0..96a774dc4cb24 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -20,16 +20,16 @@ public class FailureTest extends BaseTest { private static final String EXCEPTION_MESSAGE = "Oops"; - public static int badFunc() { + static int badFunc() { throw new RuntimeException(EXCEPTION_MESSAGE); } - public static int badFunc2() { + static int badFunc2() { System.exit(-1); return 0; } - public static int slowFunc() { + static int slowFunc() { try { Thread.sleep(10000); } catch (InterruptedException e) { @@ -76,14 +76,14 @@ public void testNormalTaskFailure() { assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc)); } - @Test + @Test(groups = {"directCall"}) public void testActorCreationFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, true); assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor)); } - @Test + @Test(groups = {"directCall"}) public void testActorTaskFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, false); @@ -102,7 +102,7 @@ public void testWorkerProcessDying() { } } - @Test + @Test(groups = {"directCall"}) public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); // This test case hangs if the worker to worker connection is implemented with grpc. diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index 86fdc1b94f92f..12df318e4453f 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -21,7 +21,7 @@ import org.testng.Assert; import org.testng.annotations.Test; - +@Test(groups = {"directCall"}) public class MultiThreadingTest extends BaseTest { private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class); @@ -30,7 +30,7 @@ public class MultiThreadingTest extends BaseTest { private static final int NUM_THREADS = 20; @RayRemote - public static Integer echo(int num) { + static Integer echo(int num) { return num; } @@ -70,7 +70,7 @@ public ActorId getCurrentActorId() { } } - public static String testMultiThreading() { + static String testMultiThreading() { Random random = new Random(); // Test calling normal functions. runTestCaseInMultipleThreads(() -> { @@ -120,12 +120,10 @@ public static String testMultiThreading() { return "ok"; } - @Test public void testInDriver() { testMultiThreading(); } - @Test public void testInWorker() { // Single-process mode doesn't have real workers. TestUtils.skipTestUnderSingleProcess(); @@ -133,7 +131,6 @@ public void testInWorker() { Assert.assertEquals("ok", obj.get()); } - @Test public void testGetCurrentActorId() { TestUtils.skipTestUnderSingleProcess(); RayActor actorIdTester = Ray.createActor(ActorIdTester::new); diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index e2efecbf222e1..334a94199e5b2 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -72,7 +72,7 @@ public int ping(int n) { } } - @Test + @Test(groups = {"directCall"}) public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); @@ -86,7 +86,6 @@ public void testSubmittingManyTasksToOneActor() { } } - @Test public void testPuttingAndGettingManyObjects() { TestUtils.skipTestUnderSingleProcess(); Integer objectToPut = 1; diff --git a/java/testng.xml b/java/testng.xml deleted file mode 100644 index c2454557226b1..0000000000000 --- a/java/testng.xml +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - - - From 534efdf54ef19e2afcf7a1730fb96c84f195430e Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 11:27:44 +0800 Subject: [PATCH 18/24] Use IAlterSuiteListener --- java/BUILD.bazel | 5 ++- java/dependencies.bzl | 2 +- java/runtime/pom.xml | 2 +- java/streaming/pom.xml | 2 +- java/test.sh | 6 +-- java/test/pom.xml | 2 +- .../org/ray/api/RayAlterSuiteListener.java | 23 +++++++++++ .../src/main/java/org/ray/api/RunTestNG.java | 39 ------------------- java/testng.xml | 13 +++++++ 9 files changed, 47 insertions(+), 47 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java delete mode 100644 java/test/src/main/java/org/ray/api/RunTestNG.java create mode 100644 java/testng.xml diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 99b51b1b9272c..37ef5b93b6eed 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -2,6 +2,7 @@ load("//bazel:ray.bzl", "define_java_module") load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ + "testng.xml", "checkstyle.xml", "checkstyle-suppressions.xml", "streaming/testng.xml", @@ -130,7 +131,9 @@ define_java_module( java_binary( name = "all_tests", - main_class = "org.ray.api.RunTestNG", + main_class = "org.testng.TestNG", + data = ["testng.xml"], + args = ["java/testng.xml"], runtime_deps = [ ":org_ray_ray_test", ":org_ray_ray_runtime_test", diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 26e36dff5a1c3..c51a181ed0e6f 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -16,7 +16,7 @@ def gen_java_deps(): "org.apache.commons:commons-lang3:3.4", "org.ow2.asm:asm:6.0", "org.slf4j:slf4j-log4j12:1.7.25", - "org.testng:testng:6.9.9", + "org.testng:testng:6.9.10", "redis.clients:jedis:2.8.0", ], repositories = [ diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 3c40f7ffc54c8..eb6c268f84557 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -75,7 +75,7 @@ org.testng testng - 6.9.9 + 6.9.10 redis.clients diff --git a/java/streaming/pom.xml b/java/streaming/pom.xml index 382233fb02af4..e624bd6e53ae4 100644 --- a/java/streaming/pom.xml +++ b/java/streaming/pom.xml @@ -50,7 +50,7 @@ org.testng testng - 6.9.9 + 6.9.10 diff --git a/java/test.sh b/java/test.sh index 7dae453af4f82..bc74a4072ab75 100755 --- a/java/test.sh +++ b/java/test.sh @@ -25,14 +25,14 @@ echo "Running tests under cluster mode." # TODO(hchen): Ideally, we should use the following bazel command to run Java tests. However, if there're skipped tests, # TestNG will exit with code 2. And bazel treats it as test failure. # bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$? -ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG +ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under cluster mode with direct actor call turned on." -ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG +ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? -run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.ray.api.RunTestNG +run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running streaming tests." run_testng java -cp $ROOT_DIR/../bazel-bin/java/streaming_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/streaming/testng.xml diff --git a/java/test/pom.xml b/java/test/pom.xml index 6a3a31d2032e5..3dfbbbae8221c 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -65,7 +65,7 @@ org.testng testng - 6.9.9 + 6.9.10 diff --git a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java new file mode 100644 index 0000000000000..6839c573134ba --- /dev/null +++ b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java @@ -0,0 +1,23 @@ +package org.ray.api; + +import java.util.List; +import org.ray.api.options.ActorCreationOptions; +import org.testng.IAlterSuiteListener; +import org.testng.xml.XmlGroups; +import org.testng.xml.XmlRun; +import org.testng.xml.XmlSuite; + +public class RayAlterSuiteListener implements IAlterSuiteListener { + + @Override + public void alter(List suites) { + XmlSuite suite = suites.get(0); + if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { + XmlGroups groups = new XmlGroups(); + XmlRun run = new XmlRun(); + run.onInclude("directCall"); + groups.setRun(run); + suite.setGroups(groups); + } + } +} \ No newline at end of file diff --git a/java/test/src/main/java/org/ray/api/RunTestNG.java b/java/test/src/main/java/org/ray/api/RunTestNG.java deleted file mode 100644 index b1d85d262cbab..0000000000000 --- a/java/test/src/main/java/org/ray/api/RunTestNG.java +++ /dev/null @@ -1,39 +0,0 @@ -package org.ray.api; - -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.List; -import org.ray.api.options.ActorCreationOptions; -import org.testng.TestNG; -import org.testng.xml.XmlGroups; -import org.testng.xml.XmlPackage; -import org.testng.xml.XmlRun; -import org.testng.xml.XmlSuite; -import org.testng.xml.XmlTest; - -public class RunTestNG { - - public static void main(String args[]) { - TestNG testng = new TestNG(); - XmlSuite suite = new XmlSuite(); - suite.setName("RAY suite"); - suite.setVerbose(2); - XmlTest test = new XmlTest(suite); - test.setName("RAY test"); - List packages = new ArrayList<>(); - packages.add(new XmlPackage("org.ray.api.test.*")); - packages.add(new XmlPackage("org.ray.runtime.*")); - test.setPackages(packages); - if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { - XmlGroups groups = new XmlGroups(); - XmlRun run = new XmlRun(); - run.onInclude("directCall"); - groups.setRun(run); - test.setGroups(groups); - } - testng.setXmlSuites(ImmutableList.of(suite)); - testng.setOutputDirectory("/tmp/ray_java_test_output"); - testng.run(); - System.exit(testng.getStatus()); - } -} diff --git a/java/testng.xml b/java/testng.xml new file mode 100644 index 0000000000000..9448cb30ffa55 --- /dev/null +++ b/java/testng.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + From 961357522af9b7b7b3229ccc03b3c94119282531 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 12:05:33 +0800 Subject: [PATCH 19/24] Minor fix about missing a test --- .../src/main/java/org/ray/api/RayAlterSuiteListener.java | 2 +- java/test/src/main/java/org/ray/api/test/FailureTest.java | 6 +++--- java/test/src/main/java/org/ray/api/test/StressTest.java | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java index 6839c573134ba..b90b9e4faa84b 100644 --- a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java +++ b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java @@ -20,4 +20,4 @@ public void alter(List suites) { suite.setGroups(groups); } } -} \ No newline at end of file +} diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 96a774dc4cb24..f2fb8b96c2757 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -20,16 +20,16 @@ public class FailureTest extends BaseTest { private static final String EXCEPTION_MESSAGE = "Oops"; - static int badFunc() { + public static int badFunc() { throw new RuntimeException(EXCEPTION_MESSAGE); } - static int badFunc2() { + public static int badFunc2() { System.exit(-1); return 0; } - static int slowFunc() { + public static int slowFunc() { try { Thread.sleep(10000); } catch (InterruptedException e) { diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index 334a94199e5b2..e9e2614635e4d 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -86,6 +86,7 @@ public void testSubmittingManyTasksToOneActor() { } } + @Test public void testPuttingAndGettingManyObjects() { TestUtils.skipTestUnderSingleProcess(); Integer objectToPut = 1; From e3d77b746ca61535b19f6427a71aec6154c2ad1a Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 17:19:28 +0800 Subject: [PATCH 20/24] Rename is_direct_call to use_direct_call in Java --- .../ray/api/options/ActorCreationOptions.java | 18 +++++++++--------- java/test.sh | 2 +- .../org/ray/api/RayAlterSuiteListener.java | 2 +- .../src/main/java/org/ray/api/TestUtils.java | 2 +- .../ray/api/test/ActorReconstructionTest.java | 2 +- src/ray/core_worker/lib/java/jni_init.cc | 10 +++++----- src/ray/core_worker/lib/java/jni_utils.h | 8 ++++---- ...org_ray_runtime_task_NativeTaskSubmitter.cc | 12 ++++++------ 8 files changed, 28 insertions(+), 28 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 949c16ad1d22c..80dddf09af226 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,20 +10,20 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); - public static final boolean DEFAULT_IS_DIRECT_CALL = "1" - .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL")); + public static final boolean DEFAULT_USE_DIRECT_CALL = "1" + .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL")); public final int maxReconstructions; - public final boolean isDirectCall; + public final boolean useDirectCall; public final String jvmOptions; private ActorCreationOptions(Map resources, int maxReconstructions, - boolean isDirectCall, String jvmOptions) { + boolean useDirectCall, String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; - this.isDirectCall = isDirectCall; + this.useDirectCall = useDirectCall; this.jvmOptions = jvmOptions; } @@ -34,7 +34,7 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; - private boolean isDirectCall = DEFAULT_IS_DIRECT_CALL; + private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL; private String jvmOptions = ""; public Builder setResources(Map resources) { @@ -49,8 +49,8 @@ public Builder setMaxReconstructions(int maxReconstructions) { // Since direct call is not fully supported yet, users are not allowed to set the option to true. // TODO (kfstorm): uncomment when direct call is ready. -// public Builder setIsDirectCall(boolean isDirectCall) { -// this.isDirectCall = isDirectCall; +// public Builder setUseDirectCall(boolean useDirectCall) { +// this.useDirectCall = useDirectCall; // return this; // } @@ -60,7 +60,7 @@ public Builder setJvmOptions(String jvmOptions) { } public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions, isDirectCall, jvmOptions); + return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions); } } diff --git a/java/test.sh b/java/test.sh index bc74a4072ab75..4612bf7e35b90 100755 --- a/java/test.sh +++ b/java/test.sh @@ -28,7 +28,7 @@ echo "Running tests under cluster mode." ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under cluster mode with direct actor call turned on." -ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_IS_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? diff --git a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java index b90b9e4faa84b..d5d042c1d1a71 100644 --- a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java +++ b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java @@ -12,7 +12,7 @@ public class RayAlterSuiteListener implements IAlterSuiteListener { @Override public void alter(List suites) { XmlSuite suite = suites.get(0); - if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { + if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) { XmlGroups groups = new XmlGroups(); XmlRun run = new XmlRun(); run.onInclude("directCall"); diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 788d0a07c153e..aff9ba0ef95d2 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -26,7 +26,7 @@ public static void skipTestUnderSingleProcess() { } public static void skipTestIfDirectActorCallEnabled() { - if (ActorCreationOptions.DEFAULT_IS_DIRECT_CALL) { + if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) { throw new SkipException("This test doesn't work when direct actor call is enabled."); } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 4028030dbaa19..43ccfe0ff0f3d 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -65,7 +65,7 @@ public void testActorReconstruction() throws InterruptedException, IOException { // Try calling increase on this actor again and check the value is now 4. int value = Ray.call(Counter::increase, actor).get(); - Assert.assertEquals(value, options.isDirectCall ? 1 : 4); + Assert.assertEquals(value, options.useDirectCall ? 1 : 4); Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get()); diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 6d63f92c87552..9bbfa3f9c4df6 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -49,9 +49,9 @@ jclass java_base_task_options_class; jfieldID java_base_task_options_resources; jclass java_actor_creation_options_class; -jfieldID java_actor_creation_options_default_is_direct_call; +jfieldID java_actor_creation_options_default_use_direct_call; jfieldID java_actor_creation_options_max_reconstructions; -jfieldID java_actor_creation_options_is_direct_call; +jfieldID java_actor_creation_options_use_direct_call; jfieldID java_actor_creation_options_jvm_options; jclass java_gcs_client_options_class; @@ -147,11 +147,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_class = LoadClass(env, "org/ray/api/options/ActorCreationOptions"); - java_actor_creation_options_default_is_direct_call = env->GetStaticFieldID( - java_actor_creation_options_class, "DEFAULT_IS_DIRECT_CALL", "Z"); + java_actor_creation_options_default_use_direct_call = env->GetStaticFieldID( + java_actor_creation_options_class, "DEFAULT_USE_DIRECT_CALL", "Z"); java_actor_creation_options_max_reconstructions = env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); - java_actor_creation_options_is_direct_call = + java_actor_creation_options_use_direct_call = env->GetFieldID(java_actor_creation_options_class, "isDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 396c4301d6637..4905df9dfe7b2 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -91,12 +91,12 @@ extern jfieldID java_base_task_options_resources; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; -/// DEFAULT_IS_DIRECT_CALL field of ActorCreationOptions class -extern jfieldID java_actor_creation_options_default_is_direct_call; +/// DEFAULT_USE_DIRECT_CALL field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_default_use_direct_call; /// maxReconstructions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_reconstructions; -/// isDirectCall field of ActorCreationOptions class -extern jfieldID java_actor_creation_options_is_direct_call; +/// useDirectCall field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_use_direct_call; /// jvmOptions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_jvm_options; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index dd5dc0610af3f..16fd0e1d19857 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -76,14 +76,14 @@ inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject call inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, jobject actorCreationOptions) { uint64_t max_reconstructions = 0; - bool is_direct_call; + bool use_direct_call; std::unordered_map resources; std::vector dynamic_worker_options; if (actorCreationOptions) { max_reconstructions = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_reconstructions)); - is_direct_call = env->GetBooleanField(actorCreationOptions, - java_actor_creation_options_is_direct_call); + use_direct_call = env->GetBooleanField(actorCreationOptions, + java_actor_creation_options_use_direct_call); jobject java_resources = env->GetObjectField(actorCreationOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); @@ -92,13 +92,13 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, java_actor_creation_options_jvm_options)); dynamic_worker_options.emplace_back(jvm_options); } else { - is_direct_call = + use_direct_call = env->GetStaticBooleanField(java_actor_creation_options_class, - java_actor_creation_options_default_is_direct_call); + java_actor_creation_options_default_use_direct_call); } ray::ActorCreationOptions action_creation_options{ - static_cast(max_reconstructions), is_direct_call, resources, + static_cast(max_reconstructions), use_direct_call, resources, dynamic_worker_options}; return action_creation_options; } From ec33ab88f8d085050b867250001ac6c277ac1754 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 17:32:48 +0800 Subject: [PATCH 21/24] Address comments --- .../java/org/ray/api/options/ActorCreationOptions.java | 5 ++++- src/ray/core_worker/context.cc | 6 ++++-- src/ray/core_worker/context.h | 4 ++-- src/ray/core_worker/transport/direct_actor_transport.cc | 2 +- src/ray/core_worker/transport/raylet_transport.cc | 2 +- src/ray/raylet/actor_registration.cc | 7 ++++--- 6 files changed, 16 insertions(+), 10 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 80dddf09af226..f150eaf983a66 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,6 +10,8 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); + // DO NOT set this environment variable. It's only used for test purposes. + // Please use `setUseDirectCall` instead. public static final boolean DEFAULT_USE_DIRECT_CALL = "1" .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL")); @@ -47,7 +49,8 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } - // Since direct call is not fully supported yet, users are not allowed to set the option to true. + // Since direct call is not fully supported yet (see issue #5559), + // users are not allowed to set the option to true. // TODO (kfstorm): uncomment when direct call is ready. // public Builder setUseDirectCall(boolean useDirectCall) { // this.useDirectCall = useDirectCall; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 9a8941624ce28..39d4ec3a5b2bc 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -80,7 +80,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); - is_direct_call_actor_ = task_spec.IsDirectCall(); + current_actor_use_direct_call_ = task_spec.IsDirectCall(); } if (task_spec.IsActorTask()) { RAY_CHECK(current_actor_id_ == task_spec.ActorId()); @@ -92,7 +92,9 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } -bool WorkerContext::IsDirectCallActor() const { return is_direct_call_actor_; } +bool WorkerContext::CurrentActorUseDirectCall() const { + return current_actor_use_direct_call_; +} WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 8317f10119e61..3c53e415e8b61 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -26,7 +26,7 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; - bool IsDirectCallActor() const; + bool CurrentActorUseDirectCall() const; int GetNextTaskIndex(); @@ -46,7 +46,7 @@ class WorkerContext { ActorID current_actor_id_; /// Whether current actor accepts direct calls. - bool is_direct_call_actor_; + bool current_actor_use_direct_call_; private: static WorkerThreadContext &GetThreadContext(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index dd42432374369..9574d0a0024f8 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -247,7 +247,7 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( nullptr); return; } - if (task_spec.IsActorTask() && !worker_context_.IsDirectCallActor()) { + if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) { send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."), nullptr, nullptr); return; diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index f99ba90e66dbe..f0f7a01cfce89 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -31,7 +31,7 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( const Task task(request.task()); const auto &task_spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); - if (task_spec.IsActorTask() && worker_context_.IsDirectCallActor()) { + if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) { send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, nullptr); return; diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 575ce095d6385..7574a57dbbe13 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -105,9 +105,10 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( const auto actor_handle_id = task->GetTaskSpecification().ActorHandleId(); const auto dummy_object = task->GetTaskSpecification().ActorDummyObject(); // Extend its frontier to include the most recent task. - // Note(hchen): this is needed because this method is called before - // `FinishAssignedTask`, which will be called when the worker tries to fetch - // the next task. + // NOTE(hchen): For non-direct-call actors, this is needed because this method is + // called before `FinishAssignedTask`, which will be called when the worker tries to + // fetch the next task. For direct-call actors, checkpoint data doesn't contain + // frontier info, so we don't need to do `ExtendFrontier` here. copy.ExtendFrontier(actor_handle_id, dummy_object); } From dc0a86f6744c2ca6a35469ab07eb6f8058279fb6 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 7 Sep 2019 19:20:32 +0800 Subject: [PATCH 22/24] Fix --- src/ray/core_worker/lib/java/jni_init.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 9bbfa3f9c4df6..14c08ab10bbff 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -152,7 +152,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_max_reconstructions = env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); java_actor_creation_options_use_direct_call = - env->GetFieldID(java_actor_creation_options_class, "isDirectCall", "Z"); + env->GetFieldID(java_actor_creation_options_class, "useDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); From 80e64ac354563ab039f7ae2581e77f0e41dfb21f Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sun, 8 Sep 2019 01:35:06 +0800 Subject: [PATCH 23/24] Fix test --- src/ray/raylet/worker.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index eb48b92d53fad..947775d6d5130 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -128,13 +128,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, task.GetTaskExecutionSpec().GetMessage()); request.set_resource_ids(resource_id_set.Serialize()); - auto status = rpc_client_->AssignTask( - request, [](Status status, const rpc::AssignTaskReply &reply) { - RAY_CHECK_OK_PREPEND(status, "Worker failed to finish executing task."); - // Worker has finished this task. There's nothing to do here - // and assigning new task will be done when raylet receives - // `TaskDone` message. - }); + auto status = rpc_client_->AssignTask(request, [](Status status, + const rpc::AssignTaskReply &reply) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Worker failed to finish executing task: " << status.ToString(); + } + // Worker has finished this task. There's nothing to do here + // and assigning new task will be done when raylet receives + // `TaskDone` message. + }); finish_assign_callback(status); if (!status.ok()) { RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId() From 99837ac46ffd87de4d0898c08e30bc5812e9eab1 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sun, 8 Sep 2019 12:24:17 +0800 Subject: [PATCH 24/24] Update ArgumentsBuilder --- .../main/java/org/ray/runtime/task/ArgumentsBuilder.java | 6 ------ .../java/org/ray/api/test/CrossLanguageInvocationTest.java | 4 +--- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index a0f8821349111..07ae3dfb1486c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -36,12 +36,6 @@ public static List wrap(Object[] args, boolean isDirectCall) { "Passing RayObject to a direct call actor is not supported."); } id = ((RayObject) arg).getId(); - } else if (arg instanceof byte[]) { - // TODO (kfstorm): This could be supported once we supported passing by value with metadata. - if (isDirectCall) { - throw new IllegalArgumentException( - "Passing raw bytes to a direct call actor is not supported."); - } } else { value = ObjectSerializer.serialize(arg); if (!isDirectCall && value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 6ad2871eae0bd..0cde562b0ab0b 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -48,9 +48,7 @@ public void testCallingPythonFunction() { @Test(groups = {"directCall"}) public void testCallingPythonActor() { - // Direct actor call only allows passing arguments as values. - // However, bytes arguments are passed from Java to Python as references. - // TODO (kfstorm): This should be supported once passing by value with metadata is allowed. + // Python worker doesn't support direct call yet. TestUtils.skipTestIfDirectActorCallEnabled(); RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); RayObject res = Ray.callPy(actor, "increase", "1".getBytes());