From 3a6f8a3d399544286c70acb95722f21bec936b8e Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 11 May 2021 07:02:33 -0700 Subject: [PATCH] [rpc] Add an Action protocol buffer. This wraps the current integer-based 'action index' in an Action protocol buffer. This is to pave the way for adding support for more complex, non-categorical action spaces. Issue #52 --- compiler_gym/envs/compiler_env.py | 3 ++- compiler_gym/envs/llvm/service/LlvmSession.cc | 2 +- compiler_gym/service/proto/__init__.py | 2 ++ compiler_gym/service/proto/compiler_gym_service.proto | 10 +++++++--- examples/RandomSearch.cc | 2 +- .../service_cc/ExampleService.cc | 2 +- .../service_py/example_service.py | 4 ++-- tests/llvm/service/GvnSinkTest.cc | 3 ++- 8 files changed, 18 insertions(+), 10 deletions(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index 7911deeee..6e29718bf 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -29,6 +29,7 @@ SessionNotFound, ) from compiler_gym.service.proto import ( + Action, AddBenchmarkRequest, EndSessionReply, EndSessionRequest, @@ -789,7 +790,7 @@ def step(self, action: Union[int, Iterable[int]]) -> StepType: # Send the request to the backend service. request = StepRequest( session_id=self._session_id, - action=actions, + action=[Action(action=a) for a in actions], observation_space=observation_indices, ) try: diff --git a/compiler_gym/envs/llvm/service/LlvmSession.cc b/compiler_gym/envs/llvm/service/LlvmSession.cc index 5a1c281f3..cd33d2ed2 100644 --- a/compiler_gym/envs/llvm/service/LlvmSession.cc +++ b/compiler_gym/envs/llvm/service/LlvmSession.cc @@ -151,7 +151,7 @@ Status LlvmSession::step(const StepRequest& request, StepReply* reply) { case LlvmActionSpace::PASSES_ALL: for (int i = 0; i < request.action_size(); ++i) { LlvmAction action; - RETURN_IF_ERROR(util::intToEnum(request.action(i), &action)); + RETURN_IF_ERROR(util::intToEnum(request.action(i).action(), &action)); RETURN_IF_ERROR(runAction(action, reply)); } } diff --git a/compiler_gym/service/proto/__init__.py b/compiler_gym/service/proto/__init__.py index 2db9c7f91..e818cce8a 100644 --- a/compiler_gym/service/proto/__init__.py +++ b/compiler_gym/service/proto/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from compiler_gym.service.proto.compiler_gym_service_pb2 import ( + Action, ActionSpace, AddBenchmarkReply, AddBenchmarkRequest, @@ -34,6 +35,7 @@ ) __all__ = [ + "Action", "ActionSpace", "AddBenchmarkReply", "AddBenchmarkRequest", diff --git a/compiler_gym/service/proto/compiler_gym_service.proto b/compiler_gym/service/proto/compiler_gym_service.proto index 7916af741..e9f6deb40 100644 --- a/compiler_gym/service/proto/compiler_gym_service.proto +++ b/compiler_gym/service/proto/compiler_gym_service.proto @@ -92,9 +92,8 @@ message StartSessionReply { message StepRequest { // The ID of the session. int64 session_id = 1; - // A list of indices into the ActionSpace.action list. Actions are executed - // in the order they appear in this list. - repeated int32 action = 2; + // A list of actions to execute, in order. + repeated Action action = 2; // A list of indices into the GetSpacesReply.observation_space_list repeated int32 observation_space = 3; } @@ -133,6 +132,11 @@ message ActionSpace { repeated string action = 2; } +message Action { + // An index into the ActionSpace.action list. + int32 action = 1; +} + // =========================================================================== // Observations. diff --git a/examples/RandomSearch.cc b/examples/RandomSearch.cc index d07e3cda9..d8d7505cb 100644 --- a/examples/RandomSearch.cc +++ b/examples/RandomSearch.cc @@ -89,7 +89,7 @@ class Environment { StepReply reply; request.set_session_id(sessionId_); - request.add_action(static_cast(action)); + request.add_action()->set_action(static_cast(action)); request.add_observation_space(static_cast(observationSpace)); RETURN_IF_ERROR(service_.Step(nullptr, &request, &reply)); CHECK(reply.observation_size() == 1); diff --git a/examples/example_compiler_gym_service/service_cc/ExampleService.cc b/examples/example_compiler_gym_service/service_cc/ExampleService.cc index 876338102..310f145fb 100644 --- a/examples/example_compiler_gym_service/service_cc/ExampleService.cc +++ b/examples/example_compiler_gym_service/service_cc/ExampleService.cc @@ -167,7 +167,7 @@ ExampleCompilationSession::ExampleCompilationSession(const std::string& benchmar Status ExampleCompilationSession::Step(const StepRequest* request, StepReply* reply) { for (int i = 0; i < request->action_size(); ++i) { - const auto action = request->action(i); + const auto action = request->action(i).action(); // Run the actual action. Here we just range check. RETURN_IF_ERROR(rangeCheck(action, 0, static_cast(actionSpace_.action_size() - 1))); } diff --git a/examples/example_compiler_gym_service/service_py/example_service.py b/examples/example_compiler_gym_service/service_py/example_service.py index 4c46b7d49..d483d4c7f 100755 --- a/examples/example_compiler_gym_service/service_py/example_service.py +++ b/examples/example_compiler_gym_service/service_py/example_service.py @@ -109,8 +109,8 @@ def step(self, request: proto.StepRequest, context) -> proto.StepReply: # Apply a list of actions from the user. Each value is an index into the # ACTIONS_SPACE.action list. for action in request.action: - logging.debug("Apply action %d", action) - if action < 0 or action >= len(ACTION_SPACE.action): + logging.debug("Apply action %d", action.action) + if action.action < 0 or action.action >= len(ACTION_SPACE.action): context.set_code(grpc.StatusCode.INVALID_ARGUMENT) context.set_details("Out-of-range") return diff --git a/tests/llvm/service/GvnSinkTest.cc b/tests/llvm/service/GvnSinkTest.cc index 2b25a216a..feaf78b92 100644 --- a/tests/llvm/service/GvnSinkTest.cc +++ b/tests/llvm/service/GvnSinkTest.cc @@ -47,7 +47,8 @@ TEST_F(GvnSinkTest, runGvnSinkOnBlowfish) { std::nullopt, workingDirectory_); StepRequest request; - request.add_action(static_cast(LlvmAction::GVNSINK_PASS)); + Action* action = request.add_action(); + action->set_action(static_cast(LlvmAction::GVNSINK_PASS))); StepReply reply; ASSERT_OK(env.Step(request, &reply)); }