diff --git a/compiler_gym/service/runtime/BUILD b/compiler_gym/service/runtime/BUILD index 5c75bc297e..05fac86e87 100644 --- a/compiler_gym/service/runtime/BUILD +++ b/compiler_gym/service/runtime/BUILD @@ -13,7 +13,16 @@ py_library( srcs = ["__init__.py"], visibility = ["//visibility:public"], deps = [ - ":benchmark_cache", + ":create_and_run_compiler_gym_service", + ], +) + +cc_library( + name = "cc_runtime", + hdrs = ["Runtime.h"], + visibility = ["//visibility:public"], + deps = [ + ":CreateAndRunCompilerGymServiceImpl", ], ) @@ -38,3 +47,68 @@ cc_library( "@glog", ], ) + +py_library( + name = "compiler_gym_service", + srcs = ["compiler_gym_service.py"], + deps = [ + ":benchmark_cache", + "//compiler_gym/service:compilation_session", + "//compiler_gym/service/proto", + "//compiler_gym/util", + ], +) + +cc_library( + name = "CompilerGymService", + hdrs = [ + "CompilerGymService.h", + "CompilerGymServiceImpl.h", + ], + visibility = ["//tests/service/runtime:__subpackages__"], + deps = [ + ":BenchmarkCache", + ":CompilerGymServiceImpl", + "//compiler_gym/service:CompilationSession", + "//compiler_gym/service/proto:compiler_gym_service_cc", + "//compiler_gym/service/proto:compiler_gym_service_cc_grpc", + "@boost//:filesystem", + "@com_github_grpc_grpc//:grpc++", + ], +) + +cc_library( + name = "CompilerGymServiceImpl", + hdrs = ["CompilerGymServiceImpl.h"], + deps = [ + "//compiler_gym/util:GrpcStatusMacros", + "//compiler_gym/util:Version", + "@fmt", + "@glog", + ], +) + +py_library( + name = "create_and_run_compiler_gym_service", + srcs = ["create_and_run_compiler_gym_service.py"], + deps = [ + ":compiler_gym_service", + "//compiler_gym/service/proto", + "//compiler_gym/util", + ], +) + +cc_library( + name = "CreateAndRunCompilerGymServiceImpl", + srcs = ["CreateAndRunCompilerGymServiceImpl.cc"], + hdrs = ["CreateAndRunCompilerGymServiceImpl.h"], + deps = [ + ":CompilerGymService", + "//compiler_gym/util:GrpcStatusMacros", + "//compiler_gym/util:Unreachable", + "@boost//:filesystem", + "@com_github_grpc_grpc//:grpc++", + "@gflags", + "@glog", + ], +) diff --git a/compiler_gym/service/runtime/CompilerGymService.h b/compiler_gym/service/runtime/CompilerGymService.h new file mode 100644 index 0000000000..106c3f76fe --- /dev/null +++ b/compiler_gym/service/runtime/CompilerGymService.h @@ -0,0 +1,88 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#include + +#include +#include + +#include "boost/filesystem.hpp" +#include "compiler_gym/service/CompilationSession.h" +#include "compiler_gym/service/proto/compiler_gym_service.grpc.pb.h" +#include "compiler_gym/service/proto/compiler_gym_service.pb.h" +#include "compiler_gym/service/runtime/BenchmarkCache.h" + +namespace compiler_gym::runtime { + +// A default implementation of the CompilerGymService. When parametrized by a +// CompilationSession subclass, this provides the RPC handling logic to run a +// gym service. +template +class CompilerGymService final : public compiler_gym::CompilerGymService::Service { + public: + CompilerGymService(const boost::filesystem::path& workingDirectory, + std::unique_ptr benchmarks = nullptr); + + // RPC endpoints. + grpc::Status GetVersion(grpc::ServerContext* context, const GetVersionRequest* request, + GetVersionReply* reply) final override; + + grpc::Status GetSpaces(grpc::ServerContext* context, const GetSpacesRequest* request, + GetSpacesReply* reply) final override; + + grpc::Status StartSession(grpc::ServerContext* context, const StartSessionRequest* request, + StartSessionReply* reply) final override; + + grpc::Status ForkSession(grpc::ServerContext* context, const ForkSessionRequest* request, + ForkSessionReply* reply) final override; + + grpc::Status EndSession(grpc::ServerContext* context, const EndSessionRequest* request, + EndSessionReply* reply) final override; + + // NOTE: Step() is not thread safe. The underlying assumption is that each + // CompilationSessionType is managed by a single thread, so race conditions + // between operations that affect the same CompilationSessionType are not + // protected against. + grpc::Status Step(grpc::ServerContext* context, const StepRequest* request, + StepReply* reply) final override; + + grpc::Status AddBenchmark(grpc::ServerContext* context, const AddBenchmarkRequest* request, + AddBenchmarkReply* reply) final override; + + inline BenchmarkCache& benchmarks() { return *benchmarks_; } + + protected: + [[nodiscard]] grpc::Status session(uint64_t id, CompilationSession** environment); + + [[nodiscard]] grpc::Status session(uint64_t id, const CompilationSession** environment) const; + + [[nodiscard]] grpc::Status action_space(const CompilationSession* session, int index, + const ActionSpace** actionSpace) const; + + [[nodiscard]] grpc::Status observation_space(const CompilationSession* session, int index, + const ObservationSpace** observationSpace) const; + + inline const boost::filesystem::path& workingDirectory() const { return workingDirectory_; } + + // Add the given session and return its ID. + uint64_t addSession(std::unique_ptr session); + + private: + const boost::filesystem::path workingDirectory_; + const std::vector actionSpaces_; + const std::vector observationSpaces_; + + std::unordered_map> sessions_; + std::unique_ptr benchmarks_; + + // Mutex used to ensure thread safety of creation and destruction of sessions. + std::mutex sessionsMutex_; + uint64_t nextSessionId_; +}; + +} // namespace compiler_gym::runtime + +#include "compiler_gym/service/runtime/CompilerGymServiceImpl.h" diff --git a/compiler_gym/service/runtime/CompilerGymServiceImpl.h b/compiler_gym/service/runtime/CompilerGymServiceImpl.h new file mode 100644 index 0000000000..4aee1c6d98 --- /dev/null +++ b/compiler_gym/service/runtime/CompilerGymServiceImpl.h @@ -0,0 +1,244 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the LICENSE file +// in the root directory of this source tree. +// +// Private implementation of the CompilerGymService template class. Do not +// include this header directly! Use +// compiler_gym/service/runtimeCompilerGymService.h. +#pragma once + +#include + +#include "compiler_gym/util/GrpcStatusMacros.h" +#include "compiler_gym/util/Version.h" + +namespace compiler_gym::runtime { + +template +CompilerGymService::CompilerGymService( + const boost::filesystem::path& workingDirectory, std::unique_ptr benchmarks) + : workingDirectory_(workingDirectory), + actionSpaces_(CompilationSessionType(workingDirectory).getActionSpaces()), + observationSpaces_(CompilationSessionType(workingDirectory).getObservationSpaces()), + benchmarks_(benchmarks ? std::move(benchmarks) : std::make_unique()), + nextSessionId_(0) {} + +template +grpc::Status CompilerGymService::GetVersion( + grpc::ServerContext* context, const GetVersionRequest* request, GetVersionReply* reply) { + VLOG(2) << "GetVersion()"; + reply->set_service_version(COMPILER_GYM_VERSION); + CompilationSessionType environment(workingDirectory()); + reply->set_compiler_version(environment.getCompilerVersion()); + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::GetSpaces(grpc::ServerContext* context, + const GetSpacesRequest* request, + GetSpacesReply* reply) { + VLOG(2) << "GetSpaces()"; + *reply->mutable_action_space_list() = {actionSpaces_.begin(), actionSpaces_.end()}; + *reply->mutable_observation_space_list() = {observationSpaces_.begin(), observationSpaces_.end()}; + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::StartSession( + grpc::ServerContext* context, const StartSessionRequest* request, StartSessionReply* reply) { + if (!request->benchmark().size()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "No benchmark URI set for StartSession()"); + } + + VLOG(1) << "StartSession(" << request->benchmark() << "), [" << nextSessionId_ << "]"; + const std::lock_guard lock(sessionsMutex_); + + const Benchmark* benchmark = benchmarks().get(request->benchmark()); + if (!benchmark) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, "Benchmark not found"); + } + + // Construct the new session. + auto environment = std::make_unique(workingDirectory()); + + // Resolve the action space. + const ActionSpace* actionSpace; + RETURN_IF_ERROR(action_space(environment.get(), request->action_space(), &actionSpace)); + + // Initialize the session. + RETURN_IF_ERROR(environment->init(*actionSpace, *benchmark)); + + // Compute the initial observations. + for (int i = 0; i < request->observation_space_size(); ++i) { + const ObservationSpace* observationSpace; + RETURN_IF_ERROR( + observation_space(environment.get(), request->observation_space(i), &observationSpace)); + RETURN_IF_ERROR(environment->computeObservation(*observationSpace, *reply->add_observation())); + } + + reply->set_session_id(addSession(std::move(environment))); + + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::ForkSession( + grpc::ServerContext* context, const ForkSessionRequest* request, ForkSessionReply* reply) { + const std::lock_guard lock(sessionsMutex_); + + CompilationSession* baseSession; + RETURN_IF_ERROR(session(request->session_id(), &baseSession)); + VLOG(1) << "ForkSession(" << request->session_id() << "), [" << nextSessionId_ << "]"; + + // Construct the new session. + auto forked = std::make_unique(workingDirectory()); + + // Initialize from the base environment. + RETURN_IF_ERROR(forked->init(baseSession)); + + reply->set_session_id(addSession(std::move(forked))); + + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::EndSession( + grpc::ServerContext* context, const EndSessionRequest* request, EndSessionReply* reply) { + VLOG(1) << "EndSession(" << request->session_id() << "), " << sessions_.size() - 1 + << " sessions remaining"; + + const std::lock_guard lock(sessionsMutex_); + + // Note that unlike the other methods, no error is thrown if the requested + // session does not exist. + if (sessions_.find(request->session_id()) != sessions_.end()) { + const CompilationSession* environment; + RETURN_IF_ERROR(session(request->session_id(), &environment)); + sessions_.erase(request->session_id()); + } + + reply->set_remaining_sessions(sessions_.size()); + return Status::OK; +} + +template +grpc::Status CompilerGymService::Step(grpc::ServerContext* context, + const StepRequest* request, + StepReply* reply) { + CompilationSession* environment; + RETURN_IF_ERROR(session(request->session_id(), &environment)); + + VLOG(2) << "Session " << request->session_id() << " Step()"; + + bool endOfEpisode = false; + std::optional newActionSpace; + bool actionsHadNoEffect = true; + + // Apply the actions. + for (int i = 0; i < request->action_size(); ++i) { + bool actionHadNoEffect = false; + std::optional newActionSpaceFromAction; + RETURN_IF_ERROR(environment->applyAction(request->action(i), endOfEpisode, + newActionSpaceFromAction, actionHadNoEffect)); + actionsHadNoEffect &= actionHadNoEffect; + if (newActionSpaceFromAction.has_value()) { + newActionSpace = *newActionSpaceFromAction; + } + if (endOfEpisode) { + break; + } + } + + // Compute the requested observations. + for (int i = 0; i < request->observation_space_size(); ++i) { + const ObservationSpace* observationSpace; + RETURN_IF_ERROR( + observation_space(environment, request->observation_space(i), &observationSpace)); + DCHECK(observationSpace) << "No observation space set"; + RETURN_IF_ERROR(environment->computeObservation(*observationSpace, *reply->add_observation())); + } + + // Call the end-of-step callback. + RETURN_IF_ERROR(environment->endOfStep(actionsHadNoEffect, endOfEpisode, newActionSpace)); + + reply->set_action_had_no_effect(actionsHadNoEffect); + if (newActionSpace.has_value()) { + *reply->mutable_new_action_space() = *newActionSpace; + } + reply->set_end_of_session(endOfEpisode); + return Status::OK; +} + +template +grpc::Status CompilerGymService::AddBenchmark( + grpc::ServerContext* context, const AddBenchmarkRequest* request, AddBenchmarkReply* reply) { + // We need to grab the sessions lock here to ensure thread safe access to the + // benchmarks cache. + const std::lock_guard lock(sessionsMutex_); + + VLOG(2) << "AddBenchmark()"; + for (int i = 0; i < request->benchmark_size(); ++i) { + benchmarks().add(std::move(request->benchmark(i))); + } + + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::session(uint64_t id, + CompilationSession** environment) { + auto it = sessions_.find(id); + if (it == sessions_.end()) { + return Status(grpc::StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id)); + } + + *environment = it->second.get(); + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::session( + uint64_t id, const CompilationSession** environment) const { + auto it = sessions_.find(id); + if (it == sessions_.end()) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, fmt::format("Session not found: {}", id)); + } + + *environment = it->second.get(); + return grpc::Status::OK; +} + +template +grpc::Status CompilerGymService::action_space( + const CompilationSession* session, int index, const ActionSpace** actionSpace) const { + if (index < 0 || index >= static_cast(actionSpaces_.size())) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + fmt::format("Action space index out of range: {}", index)); + } + *actionSpace = &actionSpaces_[index]; + return Status::OK; +} + +template +grpc::Status CompilerGymService::observation_space( + const CompilationSession* session, int index, const ObservationSpace** observationSpace) const { + if (index < 0 || index >= static_cast(observationSpaces_.size())) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + fmt::format("Observation space index out of range: {}", index)); + } + *observationSpace = &observationSpaces_[index]; + return Status::OK; +} + +template +uint64_t CompilerGymService::addSession( + std::unique_ptr session) { + uint64_t id = nextSessionId_; + sessions_[id] = std::move(session); + ++nextSessionId_; + return id; +} + +} // namespace compiler_gym::runtime diff --git a/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc new file mode 100644 index 0000000000..7ea3ebdd96 --- /dev/null +++ b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.cc @@ -0,0 +1,12 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#include "compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h" + +DEFINE_string( + working_dir, "", + "The working directory to use. Must be an existing directory with write permissions."); +DEFINE_string(port, "0", + "The port to listen on. If 0, an unused port will be selected. The selected port is " + "written to /port.txt."); diff --git a/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h new file mode 100644 index 0000000000..92e666f298 --- /dev/null +++ b/compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h @@ -0,0 +1,112 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +// +// Private implementation of the createAndRunCompilerGymService(). Do not +// include this header directly! Use compiler_gym/service/runtime/Runtime.h. +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "boost/filesystem.hpp" +#include "compiler_gym/service/proto/compiler_gym_service.pb.h" +#include "compiler_gym/service/runtime/CompilerGymService.h" +#include "compiler_gym/util/Unreachable.h" + +DECLARE_string(port); +DECLARE_string(working_dir); + +namespace compiler_gym::runtime { + +// Increase maximum message size beyond the 4MB default as inbound message +// may be larger (e.g., in the case of IR strings). +constexpr size_t kMaxMessageSizeInBytes = 512 * 1024 * 1024; + +// Create a service, configured using --port and --working_dir flags, and run +// it. This function never returns. +// +// CompilationService must be a valid compiler_gym::CompilationService subclass +// that implements the abstract methods and takes a single-argument working +// directory constructor: +// +// class MyCompilationService final : public CompilationService { +// public: +// ... +// } +// +// Usage: +// +// int main(int argc, char** argv) { +// createAndRunCompilerGymServiceImpl(argc, argv, "usage string"); +// } +template +[[noreturn]] void createAndRunCompilerGymServiceImpl(int argc, char** argv, const char* usage) { + gflags::SetUsageMessage(std::string(usage)); + // TODO: Fatal error if unparsed flags remain. + gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/false); + + // TODO: Create a temporary working directory if --working_dir is not set. + CHECK(!FLAGS_working_dir.empty()) << "--working_dir flag not set"; + if (FLAGS_port.empty()) { + FLAGS_port = "0"; + } + + const boost::filesystem::path workingDirectory = FLAGS_working_dir; + FLAGS_log_dir = workingDirectory.string() + "/logs"; + + CHECK(boost::filesystem::is_directory(FLAGS_log_dir)) << "Directory not found: " << FLAGS_log_dir; + + google::InitGoogleLogging(argv[0]); + + CompilerGymService service{workingDirectory}; + + grpc::ServerBuilder builder; + builder.RegisterService(&service); + + builder.SetMaxMessageSize(kMaxMessageSizeInBytes); + + // Start a channel on the port. + int port; + std::string serverAddress = "0.0.0.0:" + FLAGS_port; + builder.AddListeningPort(serverAddress, grpc::InsecureServerCredentials(), &port); + + // Start the server. + std::unique_ptr server(builder.BuildAndStart()); + CHECK(server) << "Failed to build RPC service"; + + { + // Write the port to a /port.txt file, which an external + // process can read to determine how to get in touch. First write the port + // to a temporary file and rename it, since renaming is atomic. + const boost::filesystem::path portPath = workingDirectory / "port.txt"; + std::ofstream out(portPath.string() + ".tmp"); + out << std::to_string(port) << std::endl; + out.close(); + boost::filesystem::rename(portPath.string() + ".tmp", portPath); + } + + { + // Write the process ID to a /pid.txt file, which can + // external process can later use to determine if this service is still + // alive. + const boost::filesystem::path pidPath = workingDirectory / "pid.txt"; + std::ofstream out(pidPath.string() + ".tmp"); + out << std::to_string(getpid()) << std::endl; + out.close(); + boost::filesystem::rename(pidPath.string() + ".tmp", pidPath); + } + + LOG(INFO) << "Service " << workingDirectory << " listening on " << port << ", PID = " << getpid(); + + server->Wait(); + UNREACHABLE("grpc::Server::Wait() should not return"); +} + +} // namespace compiler_gym::runtime diff --git a/compiler_gym/service/runtime/Runtime.h b/compiler_gym/service/runtime/Runtime.h new file mode 100644 index 0000000000..f49d0caa42 --- /dev/null +++ b/compiler_gym/service/runtime/Runtime.h @@ -0,0 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#include "compiler_gym/service/runtime/CompilerGymService.h" +#include "compiler_gym/service/runtime/CreateAndRunCompilerGymServiceImpl.h" + +namespace compiler_gym::runtime { + +template +[[noreturn]] void createAndRunCompilerGymService(int argc, char** argv, const char* usage) { + createAndRunCompilerGymServiceImpl(argc, argv, usage); +} + +} // namespace compiler_gym::runtime diff --git a/compiler_gym/service/runtime/__init__.py b/compiler_gym/service/runtime/__init__.py index 6264236915..579dd8b878 100644 --- a/compiler_gym/service/runtime/__init__.py +++ b/compiler_gym/service/runtime/__init__.py @@ -2,3 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from compiler_gym.service.runtime.create_and_run_compiler_gym_service import ( + create_and_run_compiler_gym_service, +) + +__all__ = [ + "create_and_run_compiler_gym_service", +] diff --git a/compiler_gym/service/runtime/compiler_gym_service.py b/compiler_gym/service/runtime/compiler_gym_service.py new file mode 100644 index 0000000000..2a2bd8a0e9 --- /dev/null +++ b/compiler_gym/service/runtime/compiler_gym_service.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +from contextlib import contextmanager +from pathlib import Path +from threading import Lock +from typing import Dict + +from grpc import StatusCode + +from compiler_gym.service.compilation_session import CompilationSession +from compiler_gym.service.proto import AddBenchmarkReply, AddBenchmarkRequest +from compiler_gym.service.proto import ( + CompilerGymServiceServicer as CompilerGymServiceServicerStub, +) +from compiler_gym.service.proto import ( + EndSessionReply, + EndSessionRequest, + GetSpacesReply, + GetSpacesRequest, + GetVersionReply, + GetVersionRequest, + StartSessionReply, + StartSessionRequest, + StepReply, + StepRequest, +) +from compiler_gym.service.runtime.benchmark_cache import BenchmarkCache +from compiler_gym.util.version import __version__ + + +@contextmanager +def exception_to_grpc_status(context): + def handle_exception_as(exception, code): + context.set_code(code) + context.set_details(str(exception)) + + try: + yield + except ValueError as e: + handle_exception_as(e, StatusCode.INVALID_ARGUMENT) + except LookupError as e: + handle_exception_as(e, StatusCode.NOT_FOUND) + except NotImplementedError as e: + handle_exception_as(e, StatusCode.UNIMPLEMENTED) + except FileNotFoundError as e: + handle_exception_as(e, StatusCode.UNIMPLEMENTED) + except TypeError as e: + handle_exception_as(e, StatusCode.FAILED_PRECONDITION) + except TimeoutError as e: + handle_exception_as(e, StatusCode.DEADLINE_EXCEEDED) + + +class CompilerGymService(CompilerGymServiceServicerStub): + def __init__(self, working_directory: Path, compilation_session_type): + self.working_directory = working_directory + self.benchmarks = BenchmarkCache() + + self.compilation_session_type = compilation_session_type + self.sessions: Dict[int, CompilationSession] = {} + self.sessions_lock = Lock() + self.next_session_id: int = 0 + + self.action_spaces = compilation_session_type.action_spaces + self.observation_spaces = compilation_session_type.observation_spaces + + def GetVersion(self, request: GetVersionRequest, context) -> GetVersionReply: + del context # Unused + del request # Unused + logging.debug("GetVersion()") + return GetVersionReply( + service_version=__version__, + compiler_version=self.compilation_session_type.compiler_version, + ) + + def GetSpaces(self, request: GetSpacesRequest, context) -> GetSpacesReply: + del request # Unused + logging.debug("GetSpaces()") + with exception_to_grpc_status(context): + return GetSpacesReply( + action_space_list=self.action_spaces, + observation_space_list=self.observation_spaces, + ) + + def StartSession(self, request: StartSessionRequest, context) -> StartSessionReply: + """Create a new compilation session.""" + logging.debug("StartSession(%s), [%d]", request.benchmark, self.next_session_id) + reply = StartSessionReply() + + if not request.benchmark: + context.set_code(StatusCode.INVALID_ARGUMENT) + context.set_details("No benchmark URI set for StartSession()") + return reply + + with self.sessions_lock, exception_to_grpc_status(context): + if request.benchmark not in self.benchmarks: + context.set_code(StatusCode.NOT_FOUND) + context.set_details("Benchmark not found") + return reply + + session = self.compilation_session_type( + working_directory=self.working_directory, + action_space=self.action_spaces[request.action_space], + benchmark=self.benchmarks[request.benchmark], + ) + + # Generate the initial observations. + reply.observation.extend( + [ + session.get_observation(self.observation_spaces[obs]) + for obs in request.observation_space + ] + ) + + reply.session_id = self.next_session_id + self.sessions[reply.session_id] = session + self.next_session_id += 1 + + return reply + + def EndSession(self, request: EndSessionRequest, context) -> EndSessionReply: + del context # Unused + logging.debug( + "EndSession(%d), %d sessions remaining", + request.session_id, + len(self.sessions) - 1, + ) + + with self.sessions_lock: + if request.session_id in self.sessions: + del self.sessions[request.session_id] + return EndSessionReply(remaining_sessions=len(self.sessions)) + + def Step(self, request: StepRequest, context) -> StepReply: + logging.debug("Step()") + reply = StepReply() + + if request.session_id not in self.sessions: + context.set_code(StatusCode.NOT_FOUND) + context.set_details(f"Session not found: {request.session_id}") + return reply + + session = self.sessions[request.session_id] + + reply.action_had_no_effect = True + + with exception_to_grpc_status(context): + for action in request.action: + reply.end_of_session, nas, ahne = session.apply_action(action) + reply.action_had_no_effect &= ahne + if nas: + reply.new_action_space.CopyFrom(nas) + + reply.observation.extend( + [ + session.get_observation(self.observation_spaces[obs]) + for obs in request.observation_space + ] + ) + + return reply + + def AddBenchmark(self, request: AddBenchmarkRequest, context) -> AddBenchmarkReply: + del context # Unused + reply = AddBenchmarkReply() + with self.sessions_lock: + for benchmark in request.benchmark: + self.benchmarks[benchmark.uri] = benchmark + return reply diff --git a/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py b/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py new file mode 100644 index 0000000000..b40d2cf1f9 --- /dev/null +++ b/compiler_gym/service/runtime/create_and_run_compiler_gym_service.py @@ -0,0 +1,75 @@ +#! /usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""An example CompilerGym service in python.""" +import os +from concurrent import futures +from multiprocessing import cpu_count +from pathlib import Path +from tempfile import mkdtemp + +import grpc +from absl import app, flags, logging + +from compiler_gym.service.proto import compiler_gym_service_pb2_grpc +from compiler_gym.service.runtime.compiler_gym_service import CompilerGymService +from compiler_gym.util.filesystem import atomic_file_write + +flags.DEFINE_string("working_dir", "", "Path to use as service working directory") +flags.DEFINE_integer("port", 0, "The service listening port") +flags.DEFINE_integer("nproc", cpu_count(), "The number of server worker threads") +flags.DEFINE_integer("logbuflevel", 0, "Flag for compatability with C++ service.") +FLAGS = flags.FLAGS + +MAX_MESSAGE_SIZE_IN_BYTES = 512 * 1024 * 1024 + + +def create_and_run_compiler_gym_service(compilation_session_type): + def main(argv): + argv = [x for x in argv if x.strip()] + if len(argv) != 1: + raise app.UsageError(f"Unrecognized arguments: {argv[1:]}") + + working_dir = Path(FLAGS.working_dir or mkdtemp(prefix="compiler_gym-service-")) + (working_dir / "logs").mkdir(exist_ok=True, parents=True) + + FLAGS.log_dir = str(working_dir / "logs") + logging.get_absl_handler().use_absl_log_file() + + # Create the service. + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=FLAGS.nproc), + options=[ + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE_IN_BYTES), + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE_IN_BYTES), + ], + ) + servicer = CompilerGymService( + working_directory=working_dir, + compilation_session_type=compilation_session_type, + ) + compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server( + servicer, server + ) + port = server.add_insecure_port("0.0.0.0:0") + + with atomic_file_write(working_dir / "port.txt", fileobj=True, mode="w") as f: + f.write(str(port)) + + with atomic_file_write(working_dir / "pid.txt", fileobj=True, mode="w") as f: + f.write(str(os.getpid())) + + logging.info( + "Service %s listening on %d, PID = %d", working_dir, port, os.getpid() + ) + + server.start() + server.wait_for_termination() + logging.fatal( + "Unreachable! grpc.server.wait_for_termination() should not return" + ) + + app.run(main)