Skip to content

Commit

Permalink
Merge pull request facebookresearch#258 from ChrisCummins/action-proto
Browse files Browse the repository at this point in the history
[rpc] Add an Action protocol buffer.
  • Loading branch information
ChrisCummins authored May 11, 2021
2 parents 5d20b14 + 3a6f8a3 commit 77d28bc
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 10 deletions.
3 changes: 2 additions & 1 deletion compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SessionNotFound,
)
from compiler_gym.service.proto import (
Action,
AddBenchmarkRequest,
EndSessionReply,
EndSessionRequest,
Expand Down Expand Up @@ -789,7 +790,7 @@ def step(self, action: Union[int, Iterable[int]]) -> StepType:
# Send the request to the backend service.
request = StepRequest(
session_id=self._session_id,
action=actions,
action=[Action(action=a) for a in actions],
observation_space=observation_indices,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/LlvmSession.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Status LlvmSession::step(const StepRequest& request, StepReply* reply) {
case LlvmActionSpace::PASSES_ALL:
for (int i = 0; i < request.action_size(); ++i) {
LlvmAction action;
RETURN_IF_ERROR(util::intToEnum(request.action(i), &action));
RETURN_IF_ERROR(util::intToEnum(request.action(i).action(), &action));
RETURN_IF_ERROR(runAction(action, reply));
}
}
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/service/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from compiler_gym.service.proto.compiler_gym_service_pb2 import (
Action,
ActionSpace,
AddBenchmarkReply,
AddBenchmarkRequest,
Expand Down Expand Up @@ -34,6 +35,7 @@
)

__all__ = [
"Action",
"ActionSpace",
"AddBenchmarkReply",
"AddBenchmarkRequest",
Expand Down
10 changes: 7 additions & 3 deletions compiler_gym/service/proto/compiler_gym_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ message StartSessionReply {
message StepRequest {
// The ID of the session.
int64 session_id = 1;
// A list of indices into the ActionSpace.action list. Actions are executed
// in the order they appear in this list.
repeated int32 action = 2;
// A list of actions to execute, in order.
repeated Action action = 2;
// A list of indices into the GetSpacesReply.observation_space_list
repeated int32 observation_space = 3;
}
Expand Down Expand Up @@ -133,6 +132,11 @@ message ActionSpace {
repeated string action = 2;
}

message Action {
// An index into the ActionSpace.action list.
int32 action = 1;
}

// ===========================================================================
// Observations.

Expand Down
2 changes: 1 addition & 1 deletion examples/RandomSearch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Environment {
StepReply reply;

request.set_session_id(sessionId_);
request.add_action(static_cast<int>(action));
request.add_action()->set_action(static_cast<int>(action));
request.add_observation_space(static_cast<int>(observationSpace));
RETURN_IF_ERROR(service_.Step(nullptr, &request, &reply));
CHECK(reply.observation_size() == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ ExampleCompilationSession::ExampleCompilationSession(const std::string& benchmar

Status ExampleCompilationSession::Step(const StepRequest* request, StepReply* reply) {
for (int i = 0; i < request->action_size(); ++i) {
const auto action = request->action(i);
const auto action = request->action(i).action();
// Run the actual action. Here we just range check.
RETURN_IF_ERROR(rangeCheck(action, 0, static_cast<int32_t>(actionSpace_.action_size() - 1)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def step(self, request: proto.StepRequest, context) -> proto.StepReply:
# Apply a list of actions from the user. Each value is an index into the
# ACTIONS_SPACE.action list.
for action in request.action:
logging.debug("Apply action %d", action)
if action < 0 or action >= len(ACTION_SPACE.action):
logging.debug("Apply action %d", action.action)
if action.action < 0 or action.action >= len(ACTION_SPACE.action):
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details("Out-of-range")
return
Expand Down
3 changes: 2 additions & 1 deletion tests/llvm/service/GvnSinkTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ TEST_F(GvnSinkTest, runGvnSinkOnBlowfish) {
std::nullopt, workingDirectory_);

StepRequest request;
request.add_action(static_cast<int>(LlvmAction::GVNSINK_PASS));
Action* action = request.add_action();
action->set_action(static_cast<int>(LlvmAction::GVNSINK_PASS)));
StepReply reply;
ASSERT_OK(env.Step(request, &reply));
}
Expand Down

0 comments on commit 77d28bc

Please sign in to comment.