forked from facebookresearch/CompilerGym
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[runtime] Add an in-memory cache for Benchmark protos.
This will be used by the CompilationSession runtime to keep track of the Benchmark protobufs that have been sent by the user to the service, so that CompilationSession::init() can be passed a benchmark proto. This is a generalization of the BenchmarkFactory class that is used by the LLVM service to keep a bunch of llvm::Modules loaded in memory. The same class is implemented twice in C++ and Python using the same semantics and with the same tests. The cache has a target maximum size based on the number of bytes of its elements. When this size is reached, benchamrks are evicted using a random policy. The idea behind random cache eviction is that this cache will be large enough by default to store a good number of benchmarks, so exceeding the max cache size implies a training loop in which random programs are selected from a very large pool, rather than smaller pool where an LRU policy would be better. Issue facebookresearch#254.
- Loading branch information
1 parent
40f330d
commit 29e723e
Showing
10 changed files
with
571 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
# This package implements the CompilerGym service runtime, which is the utility | ||
# code that creates RPC servers and dispatches to CompilationServices. | ||
load("@rules_cc//cc:defs.bzl", "cc_library") | ||
load("@rules_python//python:defs.bzl", "py_library") | ||
|
||
py_library( | ||
name = "runtime", | ||
srcs = ["__init__.py"], | ||
visibility = ["//visibility:public"], | ||
deps = [ | ||
":benchmark_cache", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "benchmark_cache", | ||
srcs = ["benchmark_cache.py"], | ||
visibility = ["//tests/service/runtime:__subpackages__"], | ||
deps = [ | ||
"//compiler_gym/service/proto", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "BenchmarkCache", | ||
srcs = ["BenchmarkCache.cc"], | ||
hdrs = ["BenchmarkCache.h"], | ||
visibility = ["//tests/service/runtime:__subpackages__"], | ||
deps = [ | ||
"//compiler_gym/service/proto:compiler_gym_service_cc", | ||
"@boost//:filesystem", | ||
"@com_github_grpc_grpc//:grpc++", | ||
"@glog", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
#include "compiler_gym/service/runtime/BenchmarkCache.h" | ||
|
||
#include <glog/logging.h> | ||
|
||
using grpc::Status; | ||
using grpc::StatusCode; | ||
|
||
namespace compiler_gym::runtime { | ||
|
||
BenchmarkCache::BenchmarkCache(std::optional<std::mt19937_64> rand, size_t maxSizeInBytes) | ||
: rand_(rand.has_value() ? *rand : std::mt19937_64(std::random_device()())), | ||
maxSizeInBytes_(maxSizeInBytes), | ||
sizeInBytes_(0){}; | ||
|
||
const Benchmark* BenchmarkCache::get(const std::string& uri) const { | ||
auto it = benchmarks_.find(uri); | ||
if (it == benchmarks_.end()) { | ||
return nullptr; | ||
} | ||
|
||
return &it->second; | ||
} | ||
|
||
void BenchmarkCache::add(const Benchmark&& benchmark) { | ||
VLOG(3) << "Caching benchmark " << benchmark.uri() << ". Cache size = " << sizeInBytes() | ||
<< " bytes, " << size() << " items"; | ||
|
||
// Remove any existing value to keep the cache size consistent. | ||
const auto it = benchmarks_.find(benchmark.uri()); | ||
if (it != benchmarks_.end()) { | ||
const size_t replacedSize = it->second.ByteSizeLong(); | ||
benchmarks_.erase(it); | ||
sizeInBytes_ -= replacedSize; | ||
} | ||
|
||
const size_t size = benchmark.ByteSizeLong(); | ||
if (sizeInBytes() + size > maxSizeInBytes()) { | ||
if (size > maxSizeInBytes()) { | ||
LOG(WARNING) << "Adding new benchmark with size " << size | ||
<< " bytes exceeds total target cache size of " << maxSizeInBytes() << " bytes"; | ||
} else { | ||
VLOG(3) << "Adding new benchmark with size " << size << " bytes exceeds maximum size " | ||
<< maxSizeInBytes() << " bytes, " << this->size() << " items"; | ||
} | ||
prune(); | ||
} | ||
|
||
benchmarks_.insert({benchmark.uri(), std::move(benchmark)}); | ||
sizeInBytes_ += size; | ||
} | ||
|
||
void BenchmarkCache::prune(std::optional<size_t> targetSize) { | ||
int evicted = 0; | ||
targetSize = targetSize.has_value() ? targetSize : maxSizeInBytes() / 2; | ||
|
||
while (size() && sizeInBytes() > targetSize) { | ||
// Select a benchmark randomly. | ||
std::uniform_int_distribution<size_t> distribution(0, benchmarks_.size() - 1); | ||
size_t index = distribution(rand_); | ||
auto iterator = std::next(std::begin(benchmarks_), index); | ||
|
||
// Evict the benchmark from the pool of loaded benchmarks. | ||
++evicted; | ||
sizeInBytes_ -= iterator->second.ByteSizeLong(); | ||
benchmarks_.erase(iterator); | ||
} | ||
|
||
if (evicted) { | ||
VLOG(2) << "Evicted " << evicted << " benchmarks from cache. Benchmark cache " | ||
<< "size now " << sizeInBytes() << " bytes, " << benchmarks_.size() << " items"; | ||
} | ||
} | ||
|
||
void BenchmarkCache::setMaxSizeInBytes(size_t maxSizeInBytes) { | ||
maxSizeInBytes_ = maxSizeInBytes; | ||
prune(maxSizeInBytes); | ||
} | ||
|
||
} // namespace compiler_gym::runtime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
// | ||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
#pragma once | ||
|
||
#include <grpcpp/grpcpp.h> | ||
|
||
#include <memory> | ||
#include <mutex> | ||
#include <optional> | ||
#include <random> | ||
|
||
#include "boost/filesystem.hpp" | ||
#include "compiler_gym/service/proto/compiler_gym_service.pb.h" | ||
|
||
namespace compiler_gym::runtime { | ||
|
||
constexpr size_t kEvictionSizeInBytes = 512 * 1024 * 1024; | ||
|
||
// An in-memory cache of Benchmark protocol buffers. | ||
// | ||
// This object caches Benchmark messages by URI. Once the cache reaches a | ||
// predetermined size, benchmarks are evicted randomly until the capacity is | ||
// reduced to 50%. | ||
class BenchmarkCache { | ||
public: | ||
BenchmarkCache(std::optional<std::mt19937_64> rand = std::nullopt, | ||
size_t maxSizeInBytes = kEvictionSizeInBytes); | ||
|
||
// The pointer set by benchmark is valid only until the next call to add(). | ||
const Benchmark* get(const std::string& uri) const; | ||
|
||
// Move-insert the given benchmark to the cache. | ||
void add(const Benchmark&& benchmark); | ||
|
||
inline size_t size() const { return benchmarks_.size(); }; | ||
inline size_t sizeInBytes() const { return sizeInBytes_; }; | ||
inline size_t maxSizeInBytes() const { return maxSizeInBytes_; }; | ||
|
||
void setMaxSizeInBytes(size_t maxSizeInBytes); | ||
|
||
// Evict benchmarks randomly to reduce the capacity below 50%. | ||
void prune(std::optional<size_t> targetSize = std::nullopt); | ||
|
||
private: | ||
std::unordered_map<std::string, const Benchmark> benchmarks_; | ||
|
||
std::mt19937_64 rand_; | ||
size_t maxSizeInBytes_; | ||
size_t sizeInBytes_; | ||
}; | ||
|
||
} // namespace compiler_gym::runtime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import logging | ||
from typing import Dict, Optional | ||
|
||
import numpy as np | ||
|
||
from compiler_gym.service.proto import Benchmark | ||
|
||
MAX_SIZE_IN_BYTES = 512 * 104 * 1024 | ||
|
||
|
||
class BenchmarkCache: | ||
"""An in-memory cache of Benchmark messages. | ||
This object caches Benchmark messages by URI. Once the cache reaches a | ||
predetermined size, benchmarks are evicted randomly until the capacity is | ||
reduced to 50%. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
max_size_in_bytes: int = MAX_SIZE_IN_BYTES, | ||
logger: Optional[logging.Logger] = None, | ||
rng: Optional[np.random.Generator] = None, | ||
): | ||
self.rng = rng or np.random.default_rng() | ||
self._max_size_in_bytes = max_size_in_bytes | ||
self.logger = logger or logging.getLogger("compiler_gym") | ||
|
||
self._benchmarks: Dict[str, Benchmark] = {} | ||
self._size_in_bytes = 0 | ||
|
||
def __getitem__(self, uri: str) -> Benchmark: | ||
"""Get a benchmark by URI. Raises KeyError.""" | ||
item = self._benchmarks.get(uri) | ||
if item is None: | ||
raise KeyError(uri) | ||
return item | ||
|
||
def __contains__(self, uri: str): | ||
"""Whether URI is in cache.""" | ||
return uri in self._benchmarks | ||
|
||
def __setitem__(self, uri: str, benchmark: Benchmark): | ||
"""Add benchmark to cache.""" | ||
self.logger.debug( | ||
"Caching benchmark %s. Cache size = %d bytes, %d items", | ||
uri, | ||
self.size_in_bytes, | ||
self.size, | ||
) | ||
|
||
# Remove any existing value to keep the cache size consistent. | ||
if uri in self._benchmarks: | ||
self._size_in_bytes -= self._benchmarks[uri].ByteSize() | ||
del self._benchmarks[uri] | ||
|
||
size = benchmark.ByteSize() | ||
if self.size_in_bytes + size > self.max_size_in_bytes: | ||
if size > self.max_size_in_bytes: | ||
self.logger.warning( | ||
"Adding new benchmark with size %d bytes exceeds total " | ||
"target cache size of %d bytes", | ||
size, | ||
self.max_size_in_bytes, | ||
) | ||
else: | ||
self.logger.debug( | ||
"Adding new benchmark with size %d bytes " | ||
"exceeds maximum size %d bytes, %d items", | ||
size, | ||
self.max_size_in_bytes, | ||
self.size, | ||
) | ||
self.prune() | ||
|
||
self._benchmarks[uri] = benchmark | ||
self._size_in_bytes += size | ||
|
||
def prune(self, target_size_in_bytes: Optional[int] = None) -> None: | ||
"""Evict benchmarks randomly to reduce the capacity below 50%.""" | ||
evicted = 0 | ||
target_size_in_bytes = ( | ||
self.max_size_in_bytes // 2 | ||
if target_size_in_bytes is None | ||
else target_size_in_bytes | ||
) | ||
|
||
while self.size and self.size_in_bytes > target_size_in_bytes: | ||
evicted += 1 | ||
key = self.rng.choice(list(self._benchmarks.keys())) | ||
self._size_in_bytes -= self._benchmarks[key].ByteSize() | ||
del self._benchmarks[key] | ||
|
||
if evicted: | ||
self.logger.info( | ||
"Evicted %d benchmarks from cache. " | ||
"Benchmark cache size now %d bytes, %d items", | ||
evicted, | ||
self.size_in_bytes, | ||
self.size, | ||
) | ||
|
||
@property | ||
def size(self) -> int: | ||
"""The number of items in the cache.""" | ||
return len(self._benchmarks) | ||
|
||
@property | ||
def size_in_bytes(self) -> int: | ||
"""The combined size of the elements in the cache, excluding the | ||
cache overhead. | ||
""" | ||
return self._size_in_bytes | ||
|
||
@property | ||
def max_size_in_bytes(self) -> int: | ||
"""The maximum size of the cache.""" | ||
return self._max_size_in_bytes | ||
|
||
@max_size_in_bytes.setter | ||
def max_size_in_bytes(self, value: int) -> None: | ||
"""Set a new maximum cache size.""" | ||
self._max_size_in_bytes = value | ||
self.prune(target_size_in_bytes=value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
load("@rules_cc//cc:defs.bzl", "cc_test") | ||
load("@rules_python//python:defs.bzl", "py_test") | ||
|
||
py_test( | ||
name = "benchmark_cache_test", | ||
srcs = ["benchmark_cache_test.py"], | ||
deps = [ | ||
"//compiler_gym/service/proto", | ||
"//compiler_gym/service/runtime:benchmark_cache", | ||
"//tests:test_main", | ||
], | ||
) | ||
|
||
cc_test( | ||
name = "BenchmarkCacheTest", | ||
srcs = ["BenchmarkCacheTest.cc"], | ||
deps = [ | ||
"//compiler_gym/service/proto:compiler_gym_service_cc", | ||
"//compiler_gym/service/runtime:BenchmarkCache", | ||
"//tests:TestMain", | ||
"@gtest", | ||
], | ||
) |
Oops, something went wrong.