Skip to content

Commit

Permalink
[runtime] Add an in-memory cache for Benchmark protos.
Browse files Browse the repository at this point in the history
This will be used by the CompilationSession runtime to keep track of
the Benchmark protobufs that have been sent by the user to the
service, so that CompilationSession::init() can be passed a benchmark
proto.

This is a generalization of the BenchmarkFactory class that is used by
the LLVM service to keep a bunch of llvm::Modules loaded in
memory. The same class is implemented twice in C++ and Python using
the same semantics and with the same tests.

The cache has a target maximum size based on the number of bytes of
its elements. When this size is reached, benchmarks are evicted using
a random policy. The idea behind random cache eviction is that this
cache will be large enough by default to store a good number of
benchmarks, so exceeding the max cache size implies a training loop in
which random programs are selected from a very large pool, rather than
a smaller pool where an LRU policy would be better.

Issue facebookresearch#254.
  • Loading branch information
ChrisCummins authored and bwasti committed Aug 3, 2021
1 parent 40f330d commit 29e723e
Show file tree
Hide file tree
Showing 10 changed files with 571 additions and 1 deletion.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ py_library(
"//compiler_gym/datasets",
"//compiler_gym/envs",
"//compiler_gym/service",
"//compiler_gym/service/runtime",
"//compiler_gym/spaces",
"//compiler_gym/views",
"//examples/sensitivity_analysis:action_sensitivity_analysis",
Expand Down
40 changes: 40 additions & 0 deletions compiler_gym/service/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This package implements the CompilerGym service runtime, which is the utility
# code that creates RPC servers and dispatches to CompilationServices.
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_python//python:defs.bzl", "py_library")

# Publicly visible Python package target for the service runtime. It exposes
# the package's __init__.py and depends on the benchmark cache library below.
py_library(
name = "runtime",
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
":benchmark_cache",
],
)

# Python implementation of the in-memory Benchmark proto cache. Direct
# visibility is limited to the packages under //tests/service/runtime.
py_library(
name = "benchmark_cache",
srcs = ["benchmark_cache.py"],
visibility = ["//tests/service/runtime:__subpackages__"],
deps = [
"//compiler_gym/service/proto",
],
)

# C++ implementation of the in-memory Benchmark proto cache. Direct visibility
# is limited to the packages under //tests/service/runtime.
cc_library(
name = "BenchmarkCache",
srcs = ["BenchmarkCache.cc"],
hdrs = ["BenchmarkCache.h"],
visibility = ["//tests/service/runtime:__subpackages__"],
deps = [
"//compiler_gym/service/proto:compiler_gym_service_cc",
"@boost//:filesystem",
"@com_github_grpc_grpc//:grpc++",
"@glog",
],
)
83 changes: 83 additions & 0 deletions compiler_gym/service/runtime/BenchmarkCache.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "compiler_gym/service/runtime/BenchmarkCache.h"

#include <glog/logging.h>

using grpc::Status;
using grpc::StatusCode;

namespace compiler_gym::runtime {

// Construct an empty cache.
//
// If `rand` holds a generator, a copy of it drives random eviction; otherwise
// a new generator is seeded from std::random_device. `maxSizeInBytes` is the
// size limit that add() checks before inserting a new benchmark.
BenchmarkCache::BenchmarkCache(std::optional<std::mt19937_64> rand, size_t maxSizeInBytes)
    : rand_(rand ? *rand : std::mt19937_64(std::random_device()())),
      maxSizeInBytes_(maxSizeInBytes),
      sizeInBytes_(0) {}

// Look up a cached benchmark by URI.
//
// Returns a pointer to the entry stored in the cache, or nullptr when nothing
// is cached under `uri`.
const Benchmark* BenchmarkCache::get(const std::string& uri) const {
  const auto entry = benchmarks_.find(uri);
  return entry != benchmarks_.end() ? &entry->second : nullptr;
}

// Insert `benchmark` into the cache, keyed by its URI.
//
// An entry already stored under the same URI is replaced. If the new entry
// would push the total size past maxSizeInBytes(), prune() is called (with its
// default target) before the insertion. The benchmark is inserted even when it
// is larger than the whole cache budget; that case only logs a warning.
void BenchmarkCache::add(const Benchmark&& benchmark) {
  VLOG(3) << "Caching benchmark " << benchmark.uri() << ". Cache size = " << sizeInBytes()
          << " bytes, " << size() << " items";

  // Drop any previous entry for this URI first so that its bytes are not
  // counted twice in sizeInBytes_.
  if (const auto existing = benchmarks_.find(benchmark.uri()); existing != benchmarks_.end()) {
    sizeInBytes_ -= existing->second.ByteSizeLong();
    benchmarks_.erase(existing);
  }

  const size_t newSize = benchmark.ByteSizeLong();
  const bool overBudget = sizeInBytes() + newSize > maxSizeInBytes();
  if (overBudget) {
    if (newSize <= maxSizeInBytes()) {
      VLOG(3) << "Adding new benchmark with size " << newSize << " bytes exceeds maximum size "
              << maxSizeInBytes() << " bytes, " << size() << " items";
    } else {
      // A single benchmark bigger than the whole budget cannot be made to fit
      // by eviction; warn, but still cache it.
      LOG(WARNING) << "Adding new benchmark with size " << newSize
                   << " bytes exceeds total target cache size of " << maxSizeInBytes() << " bytes";
    }
    prune();
  }

  // NOTE(review): `benchmark` is a const rvalue reference, so std::move()
  // yields a const rvalue and this presumably copies rather than moves the
  // message. Confirm, and consider taking `Benchmark&&` in the declaration.
  benchmarks_.insert({benchmark.uri(), std::move(benchmark)});
  sizeInBytes_ += newSize;
}

// Evict randomly chosen benchmarks until the cache holds at most `targetSize`
// bytes, or is empty.
//
// When `targetSize` is not provided, the target is half of maxSizeInBytes().
void BenchmarkCache::prune(std::optional<size_t> targetSize) {
  const size_t limit = targetSize.value_or(maxSizeInBytes() / 2);
  int numEvicted = 0;

  while (!benchmarks_.empty() && sizeInBytes_ > limit) {
    // Pick a victim uniformly at random among the current entries.
    std::uniform_int_distribution<size_t> distribution(0, benchmarks_.size() - 1);
    const size_t index = distribution(rand_);
    const auto victim = std::next(std::begin(benchmarks_), index);

    // Release the victim's bytes from the running total, then drop it.
    sizeInBytes_ -= victim->second.ByteSizeLong();
    benchmarks_.erase(victim);
    ++numEvicted;
  }

  if (numEvicted != 0) {
    VLOG(2) << "Evicted " << numEvicted << " benchmarks from cache. Benchmark cache "
            << "size now " << sizeInBytes() << " bytes, " << benchmarks_.size() << " items";
  }
}

// Change the maximum cache size and evict entries until the cache fits within
// the new maximum.
void BenchmarkCache::setMaxSizeInBytes(size_t maxSizeInBytes) {
  maxSizeInBytes_ = maxSizeInBytes;
  // Pass the new limit explicitly: prune()'s default target is half of the
  // maximum, which would evict more than is needed here.
  prune(maxSizeInBytes_);
}

} // namespace compiler_gym::runtime
54 changes: 54 additions & 0 deletions compiler_gym/service/runtime/BenchmarkCache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <grpcpp/grpcpp.h>

#include <cstddef>
#include <memory>
#include <mutex>
#include <optional>
#include <random>
#include <string>
#include <unordered_map>

#include "boost/filesystem.hpp"
#include "compiler_gym/service/proto/compiler_gym_service.pb.h"

namespace compiler_gym::runtime {

// Default maximum size of a BenchmarkCache, in bytes (512 MiB). Declared
// `inline` so that every translation unit that includes this header refers to
// the same single definition.
inline constexpr size_t kEvictionSizeInBytes = 512 * 1024 * 1024;

// An in-memory cache of Benchmark protocol buffers.
//
// This object caches Benchmark messages by URI. Once the cache reaches a
// predetermined size, benchmarks are evicted randomly until the capacity is
// reduced to 50%.
class BenchmarkCache {
 public:
  // Create an empty cache. `rand`, if provided, is the generator used for
  // random eviction; `maxSizeInBytes` is the size limit that triggers eviction.
  BenchmarkCache(std::optional<std::mt19937_64> rand = std::nullopt,
                 size_t maxSizeInBytes = kEvictionSizeInBytes);

  // Return the cached benchmark for `uri`, or nullptr if it is not cached.
  //
  // The returned pointer refers to an entry owned by the cache. It stops being
  // valid once that entry is replaced or evicted, which can happen on any call
  // to add(), prune(), or setMaxSizeInBytes().
  const Benchmark* get(const std::string& uri) const;

  // Insert the given benchmark into the cache, replacing any entry with the
  // same URI.
  void add(const Benchmark&& benchmark);

  // Number of benchmarks currently cached.
  size_t size() const { return benchmarks_.size(); }
  // Running total of the byte sizes of the cached benchmarks.
  size_t sizeInBytes() const { return sizeInBytes_; }
  // Size limit, in bytes, beyond which add() triggers eviction.
  size_t maxSizeInBytes() const { return maxSizeInBytes_; }

  // Set a new size limit and evict entries until the cache fits within it.
  void setMaxSizeInBytes(size_t maxSizeInBytes);

  // Evict benchmarks randomly until the cache holds at most `targetSize`
  // bytes. Without an explicit target, evicts down to 50% of the maximum size.
  void prune(std::optional<size_t> targetSize = std::nullopt);

 private:
  // Cached benchmarks, keyed by URI.
  std::unordered_map<std::string, const Benchmark> benchmarks_;

  // Generator used to choose which entry to evict.
  std::mt19937_64 rand_;
  size_t maxSizeInBytes_;
  size_t sizeInBytes_;
};

} // namespace compiler_gym::runtime
4 changes: 4 additions & 0 deletions compiler_gym/service/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
128 changes: 128 additions & 0 deletions compiler_gym/service/runtime/benchmark_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, Optional

import numpy as np

from compiler_gym.service.proto import Benchmark

# Default maximum cache size: 512 MiB. This matches the default used by the
# C++ implementation of the cache (kEvictionSizeInBytes).
MAX_SIZE_IN_BYTES = 512 * 1024 * 1024


class BenchmarkCache:
    """A size-bounded, in-memory store of Benchmark messages keyed by URI.

    When inserting a benchmark would take the combined size of the cached
    messages past the configured maximum, randomly chosen entries are evicted
    until the cache is at most half full.
    """

    def __init__(
        self,
        max_size_in_bytes: int = MAX_SIZE_IN_BYTES,
        logger: Optional[logging.Logger] = None,
        rng: Optional[np.random.Generator] = None,
    ):
        self._benchmarks: Dict[str, Benchmark] = {}
        self._size_in_bytes = 0
        self._max_size_in_bytes = max_size_in_bytes

        # Fall back to the shared "compiler_gym" logger and a freshly created
        # generator when the caller does not supply their own.
        self.logger = logger or logging.getLogger("compiler_gym")
        self.rng = rng or np.random.default_rng()

    def __getitem__(self, uri: str) -> Benchmark:
        """Return the benchmark cached under `uri`.

        :raises KeyError: If no benchmark is cached under `uri`.
        """
        cached = self._benchmarks.get(uri)
        if cached is not None:
            return cached
        raise KeyError(uri)

    def __contains__(self, uri: str):
        """Return whether a benchmark is cached under `uri`."""
        return uri in self._benchmarks

    def __setitem__(self, uri: str, benchmark: Benchmark):
        """Cache `benchmark` under `uri`, replacing any existing entry."""
        self.logger.debug(
            "Caching benchmark %s. Cache size = %d bytes, %d items",
            uri,
            self.size_in_bytes,
            self.size,
        )

        # Drop the previous entry first so its bytes are not counted twice.
        if uri in self._benchmarks:
            self._size_in_bytes -= self._benchmarks.pop(uri).ByteSize()

        new_size = benchmark.ByteSize()
        if self.size_in_bytes + new_size > self.max_size_in_bytes:
            if new_size <= self.max_size_in_bytes:
                self.logger.debug(
                    "Adding new benchmark with size %d bytes "
                    "exceeds maximum size %d bytes, %d items",
                    new_size,
                    self.max_size_in_bytes,
                    self.size,
                )
            else:
                # A single benchmark bigger than the whole budget cannot be
                # made to fit by eviction; warn, but still cache it.
                self.logger.warning(
                    "Adding new benchmark with size %d bytes exceeds total "
                    "target cache size of %d bytes",
                    new_size,
                    self.max_size_in_bytes,
                )
            self.prune()

        self._benchmarks[uri] = benchmark
        self._size_in_bytes += new_size

    def prune(self, target_size_in_bytes: Optional[int] = None) -> None:
        """Evict random benchmarks until the cache fits the target size.

        :param target_size_in_bytes: The size to shrink the cache to. If not
            provided, half of the maximum cache size is used.
        """
        if target_size_in_bytes is None:
            target_size_in_bytes = self.max_size_in_bytes // 2

        evicted = 0
        while self.size and self.size_in_bytes > target_size_in_bytes:
            victim = self.rng.choice(list(self._benchmarks))
            self._size_in_bytes -= self._benchmarks.pop(victim).ByteSize()
            evicted += 1

        if evicted:
            self.logger.info(
                "Evicted %d benchmarks from cache. "
                "Benchmark cache size now %d bytes, %d items",
                evicted,
                self.size_in_bytes,
                self.size,
            )

    @property
    def size(self) -> int:
        """The number of benchmarks currently in the cache."""
        return len(self._benchmarks)

    @property
    def size_in_bytes(self) -> int:
        """The summed size of the cached benchmarks, not counting any
        overhead of the cache itself.
        """
        return self._size_in_bytes

    @property
    def max_size_in_bytes(self) -> int:
        """The maximum size of the cache, in bytes."""
        return self._max_size_in_bytes

    @max_size_in_bytes.setter
    def max_size_in_bytes(self, value: int) -> None:
        """Set a new maximum size and evict entries until the cache fits it."""
        self._max_size_in_bytes = value
        self.prune(target_size_in_bytes=value)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def get_tag(self):
"compiler_gym.envs",
"compiler_gym.envs",
"compiler_gym.leaderboard",
"compiler_gym.service.proto",
"compiler_gym.service",
"compiler_gym.service.proto",
"compiler_gym.service.runtime",
"compiler_gym.spaces",
"compiler_gym.third_party.autophase",
"compiler_gym.third_party.inst2vec",
Expand Down
27 changes: 27 additions & 0 deletions tests/service/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
load("@rules_cc//cc:defs.bzl", "cc_test")
load("@rules_python//python:defs.bzl", "py_test")

# Test target for the Python benchmark cache library.
py_test(
name = "benchmark_cache_test",
srcs = ["benchmark_cache_test.py"],
deps = [
"//compiler_gym/service/proto",
"//compiler_gym/service/runtime:benchmark_cache",
"//tests:test_main",
],
)

# Test target for the C++ BenchmarkCache library.
cc_test(
name = "BenchmarkCacheTest",
srcs = ["BenchmarkCacheTest.cc"],
deps = [
"//compiler_gym/service/proto:compiler_gym_service_cc",
"//compiler_gym/service/runtime:BenchmarkCache",
"//tests:TestMain",
"@gtest",
],
)
Loading

0 comments on commit 29e723e

Please sign in to comment.