Skip to content

Commit

Permalink
Fix load code from local (#12102)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone authored Mar 24, 2021
1 parent 898243d commit 52cfa1c
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 65 deletions.
2 changes: 1 addition & 1 deletion dashboard/actor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def construct_actor_groups(actors):
def actor_classname_from_task_spec(task_spec):
return task_spec.get("functionDescriptor", {})\
.get("pythonFunctionDescriptor", {})\
.get("className", "Unknown actor class")
.get("className", "Unknown actor class").split(".")[-1]


def _group_actors_by_python_class(actors):
Expand Down
26 changes: 18 additions & 8 deletions python/ray/_private/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def _load_function_from_local(self, job_id, function_descriptor):
)
try:
module = importlib.import_module(module_name)
function = getattr(module, function_name)._function
parts = [part for part in function_name.split(".") if part]
object = module
for part in parts:
object = getattr(object, part)
function = object._function
self._function_execution_info[job_id][function_id] = (
FunctionExecutionInfo(
function=function,
Expand All @@ -278,7 +282,8 @@ def _load_function_from_local(self, job_id, function_descriptor):
self._num_task_executions[job_id][function_id] = 0
except Exception as e:
raise RuntimeError(f"Function {function_descriptor} failed "
"to be loaded from local code. "
"to be loaded from local code.\n"
f"sys.path: {sys.path}, "
f"Error message: {str(e)}")

def _wait_for_function(self, function_descriptor, job_id, timeout=10):
Expand Down Expand Up @@ -356,7 +361,8 @@ def export_actor_class(self, Class, actor_creation_function_descriptor,
key = (b"ActorClass:" + job_id.binary() + b":" +
actor_creation_function_descriptor.function_id.binary())
actor_class_info = {
"class_name": actor_creation_function_descriptor.class_name,
"class_name": actor_creation_function_descriptor.class_name.split(
".")[-1],
"module": actor_creation_function_descriptor.module_name,
"class": pickle.dumps(Class),
"job_id": job_id.binary(),
Expand Down Expand Up @@ -443,14 +449,18 @@ def _load_actor_class_from_local(self, job_id,
actor_creation_function_descriptor.class_name)
try:
module = importlib.import_module(module_name)
actor_class = getattr(module, class_name)
if isinstance(actor_class, ray.actor.ActorClass):
return actor_class.__ray_metadata__.modified_class
parts = [part for part in class_name.split(".") if part]
object = module
for part in parts:
object = getattr(object, part)
if isinstance(object, ray.actor.ActorClass):
return object.__ray_metadata__.modified_class
else:
return actor_class
return object
except Exception as e:
raise RuntimeError(
f"Actor {class_name} failed to be imported from local code."
f"Actor {class_name} failed to be imported from local code.\n"
f"sys.path: {sys.path}, "
f"Error Message: {str(e)}")

def _create_fake_actor_class(self, actor_class_name, actor_method_names):
Expand Down
1 change: 1 addition & 0 deletions python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ cdef class CoreWorker:
object async_thread
object async_event_loop
object plasma_event_handler
object job_config
c_bool is_local_mode

cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
Expand Down
13 changes: 8 additions & 5 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1698,11 +1698,14 @@ cdef class CoreWorker:
CNodeID.FromBinary(client_id.binary()))

def get_job_config(self):
cdef CJobConfig c_job_config = \
CCoreWorkerProcess.GetCoreWorker().GetJobConfig()
job_config = ray.gcs_utils.JobConfig()
job_config.ParseFromString(c_job_config.SerializeAsString())
return job_config
cdef CJobConfig c_job_config
# We can cache the deserialized job config object here because
# the job config will not change after a job is submitted.
if self.job_config is None:
c_job_config = CCoreWorkerProcess.GetCoreWorker().GetJobConfig()
self.job_config = ray.gcs_utils.JobConfig()
self.job_config.ParseFromString(c_job_config.SerializeAsString())
return self.job_config

cdef void async_callback(shared_ptr[CRayObject] obj,
CObjectID object_ref,
Expand Down
27 changes: 23 additions & 4 deletions python/ray/includes/function_descriptor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
Returns:
The FunctionDescriptor instance created according to the function.
"""
module_name = function.__module__
function_name = function.__name__
module_name = cls._get_module_name(function)
function_name = function.__qualname__
class_name = ""

pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest(
Expand All @@ -207,8 +207,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
Returns:
The FunctionDescriptor instance created according to the class.
"""
module_name = target_class.__module__
class_name = target_class.__name__
module_name = cls._get_module_name(target_class)
class_name = target_class.__qualname__
# Use a random uuid as function hash to solve actor name conflict.
return cls(
module_name, "__init__", class_name,
Expand Down Expand Up @@ -283,6 +283,25 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
function_id = function_id_hash.digest(ray_constants.ID_SIZE)
return ray.FunctionID(function_id)

@staticmethod
def _get_module_name(object):
"""Get the module name from object. If the module is __main__,
get the module name from file.
Returns:
Module name of object.
"""
module_name = object.__module__
if module_name == "__main__":
try:
file_path = inspect.getfile(object)
n = inspect.getmodulename(file_path)
if n:
module_name = n
except TypeError:
pass
return module_name

def is_actor_method(self):
"""Wether this function descriptor is an actor method.
Expand Down
5 changes: 5 additions & 0 deletions python/ray/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def __init__(self,
self.num_java_workers_per_process = num_java_workers_per_process
self.jvm_options = jvm_options or []
self.code_search_path = code_search_path or []
# It's difficult to find the error that caused by the
# code_search_path is a string. So we assert here.
assert isinstance(self.code_search_path, (list, tuple)), \
f"The type of code search path is incorrect: " \
f"{type(code_search_path)}"
self.runtime_env = runtime_env or dict()

def serialize(self):
Expand Down
43 changes: 43 additions & 0 deletions python/ray/tests/test_basic_2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# coding: utf-8
import os
import logging
import sys
import threading
import time
import tempfile
import subprocess

import numpy as np
import pytest
Expand Down Expand Up @@ -644,6 +647,46 @@ def test_get_correct_node_ip():
assert found_ip == "10.0.0.111"


def test_load_code_from_local(ray_start_regular_shared):
# This case writes a driver python file to a temporary directory.
#
# The driver starts a cluster with
# `ray.init(ray.job_config.JobConfig(code_search_path=<path list>))`,
# then creates a nested actor. The actor will be loaded from code in
# worker.
#
# This tests the following two cases when :
# 1) Load a nested class.
# 2) Load a class defined in the `__main__` module.
code_test = """
import os
import ray
class A:
@ray.remote
class B:
def get(self):
return "OK"
if __name__ == "__main__":
current_path = os.path.dirname(__file__)
job_config = ray.job_config.JobConfig(code_search_path=[current_path])
ray.init({}, job_config=job_config)
b = A.B.remote()
print(ray.get(b.get.remote()))
"""

# Test code search path contains space.
with tempfile.TemporaryDirectory(suffix="a b") as tmpdir:
test_driver = os.path.join(tmpdir, "test_load_code_from_local.py")
with open(test_driver, "w") as f:
f.write(
code_test.format(
repr(ray_start_regular_shared["redis_address"])))
output = subprocess.check_output([sys.executable, test_driver])
assert b"OK" in output


if __name__ == "__main__":
import pytest
# Skip test_basic_2_client_mode for now- the test suite is breaking.
Expand Down
29 changes: 12 additions & 17 deletions python/ray/workers/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@
type=str,
default="",
help="The configuration of object spilling. Only used by I/O workers.")
parser.add_argument(
"--code-search-path",
default=None,
type=str,
help="A list of directories or jar files separated by colon that specify "
"the search path for user code. This will be used as `CLASSPATH` in "
"Java and `PYTHONPATH` in Python.")
parser.add_argument(
"--logging-rotate-bytes",
required=False,
Expand Down Expand Up @@ -156,16 +149,6 @@
if raylet_ip_address is None:
raylet_ip_address = args.node_ip_address

code_search_path = args.code_search_path
load_code_from_local = False
if code_search_path is not None:
load_code_from_local = True
for p in code_search_path.split(":"):
if os.path.isfile(p):
p = os.path.dirname(p)
sys.path.append(p)
ray.worker.global_worker.set_load_code_from_local(load_code_from_local)

ray_params = RayParams(
node_ip_address=args.node_ip_address,
raylet_ip_address=raylet_ip_address,
Expand All @@ -187,6 +170,18 @@
ray.worker._global_node = node
ray.worker.connect(node, mode=mode)

# Add code search path to sys.path, set load_code_from_local.
core_worker = ray.worker.global_worker.core_worker
code_search_path = core_worker.get_job_config().code_search_path
load_code_from_local = False
if code_search_path:
load_code_from_local = True
for p in code_search_path:
if os.path.isfile(p):
p = os.path.dirname(p)
sys.path.insert(0, p)
ray.worker.global_worker.set_load_code_from_local(load_code_from_local)

# Setup log file.
out_file, err_file = node.get_log_file_handles(
get_worker_log_file_name(args.worker_type))
Expand Down
7 changes: 6 additions & 1 deletion src/ray/common/function_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,12 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface {
virtual std::string CallString() const {
const std::string &class_name = typed_message_->class_name();
const std::string &function_name = typed_message_->function_name();
return class_name.empty() ? function_name : class_name + "." + function_name;
if (class_name.empty()) {
return function_name.substr(function_name.find_last_of(".") + 1);
} else {
return class_name.substr(class_name.find_last_of(".") + 1) + "." +
function_name.substr(function_name.find_last_of(".") + 1);
}
}

const std::string &ModuleName() const { return typed_message_->module_name(); }
Expand Down
46 changes: 17 additions & 29 deletions src/ray/raylet/worker_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,41 +182,29 @@ Process WorkerPool::StartWorkerProcess(
}
}

if (job_config) {
// Note that we push the item to the front of the vector to make
// sure this is the freshest option than others.
if (!job_config->jvm_options().empty()) {
dynamic_options.insert(dynamic_options.begin(), job_config->jvm_options().begin(),
job_config->jvm_options().end());
}

std::string code_search_path_str;
for (int i = 0; i < job_config->code_search_path_size(); i++) {
auto path = job_config->code_search_path(i);
if (i != 0) {
code_search_path_str += ":";
if (language == Language::JAVA) {
if (job_config) {
// Note that we push the item to the front of the vector to make
// sure this is the freshest option than others.
if (!job_config->jvm_options().empty()) {
dynamic_options.insert(dynamic_options.begin(), job_config->jvm_options().begin(),
job_config->jvm_options().end());
}
code_search_path_str += path;
}
if (!code_search_path_str.empty()) {
switch (language) {
case Language::PYTHON: {
code_search_path_str = "--code-search-path=" + code_search_path_str;
break;

std::string code_search_path_str;
for (int i = 0; i < job_config->code_search_path_size(); i++) {
auto path = job_config->code_search_path(i);
if (i != 0) {
code_search_path_str += ":";
}
code_search_path_str += path;
}
case Language::JAVA: {
if (!code_search_path_str.empty()) {
code_search_path_str = "-Dray.job.code-search-path=" + code_search_path_str;
break;
dynamic_options.push_back(code_search_path_str);
}
default:
RAY_LOG(FATAL) << "code_search_path is not supported for worker language "
<< language;
}
dynamic_options.push_back(code_search_path_str);
}
}

if (language == Language::JAVA) {
dynamic_options.push_back("-Dray.job.num-java-workers-per-process=" +
std::to_string(workers_to_start));
}
Expand Down

0 comments on commit 52cfa1c

Please sign in to comment.