diff --git a/tensorflow_serving/config/BUILD b/tensorflow_serving/config/BUILD index 21c198ef4e7..60be85c9ef1 100644 --- a/tensorflow_serving/config/BUILD +++ b/tensorflow_serving/config/BUILD @@ -88,6 +88,15 @@ serving_proto_library_py( ], ) +serving_proto_library( + name = "monitoring_config_proto", + srcs = ["monitoring_config.proto"], + cc_api_version = 2, + java_api_version = 2, + deps = [ + ], +) + serving_proto_library( name = "ssl_config_proto", srcs = ["ssl_config.proto"], diff --git a/tensorflow_serving/config/monitoring_config.proto b/tensorflow_serving/config/monitoring_config.proto new file mode 100644 index 00000000000..9da3700de46 --- /dev/null +++ b/tensorflow_serving/config/monitoring_config.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package tensorflow.serving; +option cc_enable_arenas = true; + +// Configuration for Prometheus monitoring. +message PrometheusConfig { + // Whether to expose Prometheus metrics. + bool enable = 1; + + // The endpoint to expose Prometheus metrics. + // If not specified, PrometheusExporter::kPrometheusPath value is used. + string path = 2; +} + +// Configuration for monitoring. 
+message MonitoringConfig { + PrometheusConfig prometheus_config = 1; +} diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD index 3aabb9d57db..a6d6eb06d99 100644 --- a/tensorflow_serving/model_servers/BUILD +++ b/tensorflow_serving/model_servers/BUILD @@ -243,10 +243,11 @@ cc_library( deps = [ ":http_rest_api_handler", ":server_core", + "//tensorflow_serving/config:monitoring_config_proto", + "//tensorflow_serving/util:prometheus_exporter", "//tensorflow_serving/util:threadpool_executor", "//tensorflow_serving/util/net_http/server/public:http_server", "//tensorflow_serving/util/net_http/server/public:http_server_api", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@org_tensorflow//tensorflow/core:lib", @@ -318,6 +319,7 @@ cc_library( "@com_google_absl//absl/memory", "@org_tensorflow//tensorflow/core:protos_all_cc", "//tensorflow_serving/config:model_server_config_proto", + "//tensorflow_serving/config:monitoring_config_proto", "//tensorflow_serving/config:ssl_config_proto", "//tensorflow_serving/core:availability_preserving_policy", "//tensorflow_serving/servables/tensorflow:session_bundle_config_proto", @@ -366,6 +368,7 @@ py_test( "//tensorflow_serving/servables/tensorflow/testdata:half_plus_two/00000123/export.data-00000-of-00001", "//tensorflow_serving/servables/tensorflow/testdata:half_plus_two/00000123/export.index", "//tensorflow_serving/servables/tensorflow/testdata:half_plus_two/00000123/export.meta", + "//tensorflow_serving/servables/tensorflow/testdata:monitoring_config.txt", "//tensorflow_serving/servables/tensorflow/testdata:saved_model_half_plus_three/00000123/assets/foo.txt", "//tensorflow_serving/servables/tensorflow/testdata:saved_model_half_plus_three/00000123/saved_model.pb", "//tensorflow_serving/servables/tensorflow/testdata:saved_model_half_plus_three/00000123/variables/variables.data-00000-of-00001", diff --git 
a/tensorflow_serving/model_servers/http_server.cc b/tensorflow_serving/model_servers/http_server.cc index 4f4ad460516..584b5db541d 100644 --- a/tensorflow_serving/model_servers/http_server.cc +++ b/tensorflow_serving/model_servers/http_server.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "re2/re2.h" #include "tensorflow/core/platform/env.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow_serving/util/net_http/server/public/httpserver.h" #include "tensorflow_serving/util/net_http/server/public/response_code_enum.h" #include "tensorflow_serving/util/net_http/server/public/server_request_interface.h" +#include "tensorflow_serving/util/prometheus_exporter.h" #include "tensorflow_serving/util/threadpool_executor.h" namespace tensorflow { @@ -79,6 +82,35 @@ net_http::HTTPStatusCode ToHTTPStatusCode(const Status& status) { } } +void ProcessPrometheusRequest(PrometheusExporter* exporter, + const PrometheusConfig& prometheus_config, + net_http::ServerRequestInterface* req) { + std::vector> headers; + headers.push_back({"Content-Type", "text/plain"}); + string output; + Status status; + // Check if url matches the path. + if (req->uri_path() != prometheus_config.path()) { + output = absl::StrFormat("Unexpected path: %s. Should be %s", + req->uri_path(), prometheus_config.path()); + status = Status(error::Code::INVALID_ARGUMENT, output); + } else { + status = exporter->GeneratePage(&output); + } + const net_http::HTTPStatusCode http_status = ToHTTPStatusCode(status); + // Note: we add headers+output for non successful status too, in case the + // output contains details about the error (e.g. error messages). 
+ for (const auto& kv : headers) { + req->OverwriteResponseHeader(kv.first, kv.second); + } + req->WriteResponseString(output); + if (http_status != net_http::HTTPStatusCode::OK) { + VLOG(1) << "Error Processing prometheus metrics request. Error: " + << status.ToString(); + } + req->ReplyWithStatus(http_status); +} + class RequestExecutor final : public net_http::EventExecutor { public: explicit RequestExecutor(int num_threads) @@ -147,7 +179,8 @@ class RestApiRequestDispatcher { } // namespace std::unique_ptr CreateAndStartHttpServer( - int port, int num_threads, int timeout_in_ms, ServerCore* core) { + int port, int num_threads, int timeout_in_ms, + const MonitoringConfig& monitoring_config, ServerCore* core) { auto options = absl::make_unique(); options->AddPort(static_cast(port)); options->SetExecutor(absl::make_unique(num_threads)); @@ -157,6 +190,25 @@ std::unique_ptr CreateAndStartHttpServer( return nullptr; } + // Register handler for prometheus metric endpoint. + if (monitoring_config.prometheus_config().enable()) { + std::shared_ptr exporter = + std::make_shared(); + net_http::RequestHandlerOptions prometheus_request_options; + PrometheusConfig prometheus_config = monitoring_config.prometheus_config(); + // An unset path falls back to the default endpoint, as documented on + // PrometheusConfig.path; the same effective path is used for matching. + if (prometheus_config.path().empty()) { + prometheus_config.set_path(PrometheusExporter::kPrometheusPath); + } + server->RegisterRequestHandler( + prometheus_config.path(), + [exporter, prometheus_config](net_http::ServerRequestInterface* req) { + ProcessPrometheusRequest(exporter.get(), prometheus_config, req); + }, + prometheus_request_options); + } + std::shared_ptr dispatcher = std::make_shared(timeout_in_ms, core); net_http::RequestHandlerOptions handler_options; diff --git a/tensorflow_serving/model_servers/http_server.h b/tensorflow_serving/model_servers/http_server.h index 5beea8a5d82..f346ff45383 100644 --- a/tensorflow_serving/model_servers/http_server.h +++ b/tensorflow_serving/model_servers/http_server.h @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow_serving/config/monitoring_config.pb.h" #include "tensorflow_serving/util/net_http/server/public/httpserver_interface.h" namespace tensorflow { @@ -30,7 +31,8 @@ class ServerCore; // // The returned server is in a state of accepting new requests. std::unique_ptr CreateAndStartHttpServer( - int port, int num_threads, int timeout_in_ms, ServerCore* core); + int port, int num_threads, int timeout_in_ms, + const MonitoringConfig& monitoring_config, ServerCore* core); } // namespace serving } // namespace tensorflow diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 3bc94d7398e..4ba9e2c05e3 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -135,7 +135,11 @@ int main(int argc, char** argv) { "Enables model warmup, which triggers lazy " "initializations (such as TF optimizations) at load " "time, to reduce first request latency."), - tensorflow::Flag("version", &display_version, "Display version")}; + tensorflow::Flag("version", &display_version, "Display version"), + tensorflow::Flag( + "monitoring_config_file", &options.monitoring_config_file, + "If non-empty, read an ascii MonitoringConfig protobuf from " + "the supplied file name")}; const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc index 4b654f2520e..92178d55255 100644 --- a/tensorflow_serving/model_servers/server.cc +++ b/tensorflow_serving/model_servers/server.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow_serving/config/model_server_config.pb.h" +#include "tensorflow_serving/config/monitoring_config.pb.h" #include "tensorflow_serving/config/ssl_config.pb.h" #include "tensorflow_serving/core/availability_preserving_policy.h" #include "tensorflow_serving/model_servers/grpc_status_util.h" @@ -285,9 +286,15 @@ Status Server::BuildAndStart(const Options& server_options) { if (server_options.http_port != server_options.grpc_port) { const string server_address = "localhost:" + std::to_string(server_options.http_port); + MonitoringConfig monitoring_config; + if (!server_options.monitoring_config_file.empty()) { + monitoring_config = ReadProtoFromFile( + server_options.monitoring_config_file); + } http_server_ = CreateAndStartHttpServer( server_options.http_port, server_options.http_num_threads, - server_options.http_timeout_in_ms, server_core_.get()); + server_options.http_timeout_in_ms, monitoring_config, + server_core_.get()); if (http_server_ != nullptr) { LOG(INFO) << "Exporting HTTP/REST API at:" << server_address << " ..."; } else { diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h index ebc5c7be4b6..eae27d6a978 100644 --- a/tensorflow_serving/model_servers/server.h +++ b/tensorflow_serving/model_servers/server.h @@ -65,6 +65,7 @@ class Server { tensorflow::string ssl_config_file; string model_config_file; bool enable_model_warmup = true; + tensorflow::string monitoring_config_file; Options(); }; diff --git a/tensorflow_serving/model_servers/tensorflow_model_server_test.py b/tensorflow_serving/model_servers/tensorflow_model_server_test.py index e0a24919329..090ca0f4245 100644 --- a/tensorflow_serving/model_servers/tensorflow_model_server_test.py +++ b/tensorflow_serving/model_servers/tensorflow_model_server_test.py @@ -117,6 +117,7 @@ def GetArgsKey(*args, **kwargs): def RunServer(model_name, model_path, 
model_config_file=None, + monitoring_config_file=None, batching_parameters_file=None, grpc_channel_arguments='', wait_for_server_ready=True, @@ -131,6 +132,7 @@ def RunServer(model_name, model_name: Name of model. model_path: Path to model. model_config_file: Path to model config file. + monitoring_config_file: Path to the monitoring config file. batching_parameters_file: Path to batching parameters. grpc_channel_arguments: Custom gRPC args for server. wait_for_server_ready: Wait for gRPC port to be ready. @@ -165,6 +167,9 @@ def RunServer(model_name, else: raise ValueError('Both model_config_file and model_path cannot be empty!') + if monitoring_config_file: + command += ' --monitoring_config_file=' + monitoring_config_file + if batching_parameters_file: command += ' --enable_batching' command += ' --batching_parameters_file=' + batching_parameters_file @@ -287,6 +292,10 @@ def _GetBatchingParametersFile(self): """Returns a path to a batching configuration file.""" return os.path.join(self.testdata_dir, 'batching_config.txt') + def _GetMonitoringConfigFile(self): + """Returns a path to a monitoring configuration file.""" + return os.path.join(self.testdata_dir, 'monitoring_config.txt') + def _VerifyModelSpec(self, actual_model_spec, exp_model_name, @@ -642,6 +651,27 @@ def testGetStatusREST(self): }] }) + def testPrometheusEndpoint(self): + """Test that Prometheus metrics are served at the configured HTTP path.""" + model_path = self._GetSavedModelBundlePath() + host, port = TensorflowModelServerTest.RunServer( + 'default', + model_path, + monitoring_config_file=self._GetMonitoringConfigFile())[2].split(':') + + # Prepare request + url = 'http://{}:{}/monitoring/prometheus/metrics'.format(host, port) + + # Send request + resp_data = None + try: + resp_data = CallREST(url, None) + except Exception as e: # pylint: disable=broad-except + self.fail('Request failed with error: {}'.format(e)) + + # Verify that the response contains metric type information. 
+ self.assertIn('# TYPE', resp_data) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_serving/servables/tensorflow/testdata/BUILD b/tensorflow_serving/servables/tensorflow/testdata/BUILD index b2520bafea9..0ed3e6fe0fc 100644 --- a/tensorflow_serving/servables/tensorflow/testdata/BUILD +++ b/tensorflow_serving/servables/tensorflow/testdata/BUILD @@ -86,4 +86,5 @@ exports_files([ "good_model_config.txt", "bad_model_config.txt", "batching_config.txt", + "monitoring_config.txt", ]) diff --git a/tensorflow_serving/servables/tensorflow/testdata/monitoring_config.txt b/tensorflow_serving/servables/tensorflow/testdata/monitoring_config.txt new file mode 100644 index 00000000000..3ae0fa2b3b9 --- /dev/null +++ b/tensorflow_serving/servables/tensorflow/testdata/monitoring_config.txt @@ -0,0 +1,4 @@ +prometheus_config: { + enable: true, + path: "/monitoring/prometheus/metrics" +} diff --git a/tensorflow_serving/util/BUILD b/tensorflow_serving/util/BUILD index afb0e21b2a8..5db01916af3 100644 --- a/tensorflow_serving/util/BUILD +++ b/tensorflow_serving/util/BUILD @@ -62,6 +62,18 @@ cc_library( ], ) +cc_library( + name = "prometheus_exporter", + srcs = ["prometheus_exporter.cc"], + hdrs = ["prometheus_exporter.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_googlesource_code_re2//:re2", + "@org_tensorflow//tensorflow/core:lib", + ], +) + ############################################################################### # Internal targets ############################################################################### @@ -87,6 +99,18 @@ cc_test( ], ) +cc_test( + name = "prometheus_exporter_test", + size = "small", + srcs = ["prometheus_exporter_test.cc"], + deps = [ + ":prometheus_exporter", + "//tensorflow_serving/core/test_util:test_main", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:lib", + ], +) + cc_test( name = "event_bus_test", size = "small", diff --git 
a/tensorflow_serving/util/prometheus_exporter.cc b/tensorflow_serving/util/prometheus_exporter.cc new file mode 100644 index 00000000000..78268d9fb1c --- /dev/null +++ b/tensorflow_serving/util/prometheus_exporter.cc @@ -0,0 +1,190 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_serving/util/prometheus_exporter.h" + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "re2/re2.h" + +namespace tensorflow { +namespace serving { + +namespace { + +string SanitizeLabelValue(const string& value) { + // Backslash, double quote and line feed have to be escaped. + string new_value = value; + // Replace \ with \\. + RE2::GlobalReplace(&new_value, "\\\\", "\\\\\\\\"); + // Replace " with \". + RE2::GlobalReplace(&new_value, "\"", "\\\\\""); + // Replace line feed with \n (backslash followed by 'n'). + RE2::GlobalReplace(&new_value, "\n", "\\\\n"); + return new_value; +} + +string SanatizeLabelName(const string& name) { + // Valid format: [a-zA-Z_][a-zA-Z0-9_]* + string new_name = name; + RE2::GlobalReplace(&new_name, "[^a-zA-Z0-9]", "_"); + if (RE2::FullMatch(new_name, "^[0-9].*")) { + // Start with 0-9, prepend an underscore. 
+ new_name = absl::StrCat("_", new_name); + } + return new_name; +} + +string GetPrometheusMetricName( + const monitoring::MetricDescriptor& metric_descriptor) { + // Valid format: [a-zA-Z_:][a-zA-Z0-9_:]* + string new_name = metric_descriptor.name; + RE2::GlobalReplace(&new_name, "[^a-zA-Z0-9_]", ":"); + if (RE2::FullMatch(new_name, "^[0-9].*")) { + // Start with 0-9, prepend an underscore. + new_name = absl::StrCat("_", new_name); + } + return new_name; +} + +void SerializeHistogram(const monitoring::MetricDescriptor& metric_descriptor, + const monitoring::PointSet& point_set, + std::vector* lines) { + // For a metric name NAME, we should output: + // NAME_bucket{le=b1} x1 + // NAME_bucket{le=b2} x2 + // NAME_bucket{le=b3} x3 ... + // NAME_sum xsum + // NAME_count xcount + string prom_metric_name = GetPrometheusMetricName(metric_descriptor); + // Type definition line. + lines->push_back(absl::StrFormat("# TYPE %s histogram", prom_metric_name)); + for (const auto& point : point_set.points) { + // Each point has different label values. + std::vector labels = {}; + labels.reserve(point->labels.size()); + for (const auto& label : point->labels) { + labels.push_back(absl::StrFormat("%s=\"%s\"", + SanatizeLabelName(label.name), + SanitizeLabelValue(label.value))); + } + int64 cumulative_count = 0; + string bucket_prefix = + absl::StrCat(prom_metric_name, "_bucket{", absl::StrJoin(labels, ",")); + if (!labels.empty()) { + absl::StrAppend(&bucket_prefix, ","); + } + // One bucket per line, last one should be le="+Inf". + for (int i = 0; i < point->histogram_value.bucket_size(); i++) { + cumulative_count += point->histogram_value.bucket(i); + string bucket_limit = + (i < point->histogram_value.bucket_size() - 1) + ? absl::StrCat(point->histogram_value.bucket_limit(i)) + : "+Inf"; + lines->push_back(absl::StrCat( + bucket_prefix, absl::StrFormat("le=\"%s\"} ", bucket_limit), + cumulative_count)); + } + // _sum and _count. 
+ lines->push_back(absl::StrCat(prom_metric_name, "_sum{", + absl::StrJoin(labels, ","), "} ", + point->histogram_value.sum())); + lines->push_back(absl::StrCat(prom_metric_name, "_count{", + absl::StrJoin(labels, ","), "} ", + cumulative_count)); + } +} + +void SerializeScalar(const monitoring::MetricDescriptor& metric_descriptor, + const monitoring::PointSet& point_set, + std::vector* lines) { + // A counter or gauge metric. + // The format should be: + // NAME{label=value,label=value} x + string prom_metric_name = GetPrometheusMetricName(metric_descriptor); + string metric_type_str = "untyped"; + if (metric_descriptor.metric_kind == monitoring::MetricKind::kCumulative) { + metric_type_str = "counter"; + } else if (metric_descriptor.metric_kind == monitoring::MetricKind::kGauge) { + metric_type_str = "gauge"; + } + // Type definition line. + lines->push_back( + absl::StrFormat("# TYPE %s %s", prom_metric_name, metric_type_str)); + for (const auto& point : point_set.points) { + // Each point has different label values. 
+ string name_bracket = absl::StrCat(prom_metric_name, "{"); + std::vector labels = {}; + labels.reserve(point->labels.size()); + for (const auto& label : point->labels) { + labels.push_back(absl::StrFormat("%s=\"%s\"", + SanatizeLabelName(label.name), + SanitizeLabelValue(label.value))); + } + lines->push_back(absl::StrCat(name_bracket, absl::StrJoin(labels, ","), + absl::StrFormat("} %d", point->int64_value))); + } +} + +void SerializeMetric(const monitoring::MetricDescriptor& metric_descriptor, + const monitoring::PointSet& point_set, + std::vector* lines) { + if (metric_descriptor.value_type == monitoring::ValueType::kHistogram) { + SerializeHistogram(metric_descriptor, point_set, lines); + } else { + SerializeScalar(metric_descriptor, point_set, lines); + } +} + +} // namespace + +const char* const PrometheusExporter::kPrometheusPath = + "/monitoring/prometheus/metrics"; + +PrometheusExporter::PrometheusExporter() + : collection_registry_(monitoring::CollectionRegistry::Default()) {} + +Status PrometheusExporter::GeneratePage(string* http_page) { + if (http_page == nullptr) { + return Status(error::Code::INVALID_ARGUMENT, "Http page pointer is null"); + } + monitoring::CollectionRegistry::CollectMetricsOptions collect_options; + collect_options.collect_metric_descriptors = true; + const std::unique_ptr collected_metrics = + collection_registry_->CollectMetrics(collect_options); + + const auto& descriptor_map = collected_metrics->metric_descriptor_map; + const auto& metric_map = collected_metrics->point_set_map; + + std::vector lines; + for (const auto& name_and_metric_descriptor : descriptor_map) { + const string& metric_name = name_and_metric_descriptor.first; + auto metric_iterator = metric_map.find(metric_name); + if (metric_iterator == metric_map.end()) { + // Not found. 
+ continue; + } + SerializeMetric(*name_and_metric_descriptor.second, + *(metric_iterator->second), &lines); + } + *http_page = absl::StrJoin(lines, "\n"); + absl::StrAppend(http_page, "\n"); + return Status::OK(); +} + +} // namespace serving +} // namespace tensorflow diff --git a/tensorflow_serving/util/prometheus_exporter.h b/tensorflow_serving/util/prometheus_exporter.h new file mode 100644 index 00000000000..1d485771e37 --- /dev/null +++ b/tensorflow_serving/util/prometheus_exporter.h @@ -0,0 +1,49 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_SERVING_UTIL_PROMETHEUS_EXPORTER_H_ +#define TENSORFLOW_SERVING_UTIL_PROMETHEUS_EXPORTER_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" + +namespace tensorflow { +namespace serving { + +// Exports metrics in Prometheus monitoring format. +class PrometheusExporter { + public: + // Default path to expose the metrics. + static const char* const kPrometheusPath; + + PrometheusExporter(); + + // Generates text page in Prometheus format: + // https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-example + // If an error status is returned, http_page is unchanged. 
+ Status GeneratePage(string* http_page); + + private: + // The metrics registry. + monitoring::CollectionRegistry* collection_registry_; +}; + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_SERVING_UTIL_PROMETHEUS_EXPORTER_H_ diff --git a/tensorflow_serving/util/prometheus_exporter_test.cc b/tensorflow_serving/util/prometheus_exporter_test.cc new file mode 100644 index 00000000000..18d4d82ff3f --- /dev/null +++ b/tensorflow_serving/util/prometheus_exporter_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_serving/util/prometheus_exporter.h" + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" + +namespace tensorflow { +namespace serving { +namespace { + +TEST(PrometheusExporterTest, Counter) { + auto exporter = absl::make_unique(); + auto counter = absl::WrapUnique( + monitoring::Counter<1>::New("/test/path/total", "A counter.", "name")); + counter->GetCell("abc")->IncrementBy(2); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + + string expected_result = absl::StrJoin( + {"# TYPE :test:path:total counter", ":test:path:total{name=\"abc\"} 2"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + EXPECT_EQ(http_page, expected_result); +} + +TEST(PrometheusExporterTest, Gauge) { + auto exporter = absl::make_unique(); + auto gauge = absl::WrapUnique(monitoring::Gauge::New( + "/test/path/gague", "A gauge", "x", "y")); + gauge->GetCell("abc", "def")->Set(5); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + string expected_result = + absl::StrJoin({"# TYPE :test:path:gague gauge", + ":test:path:gague{x=\"abc\",y=\"def\"} 5"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + EXPECT_EQ(http_page, expected_result); +} + +TEST(PrometheusExporterTest, Histogram) { + auto exporter = absl::make_unique(); + auto histogram = absl::WrapUnique(monitoring::Sampler<1>::New( + {"/test/path/histogram", "A histogram.", "status"}, + monitoring::Buckets::Exponential(1, 2, 10))); + histogram->GetCell("good")->Add(2); + histogram->GetCell("good")->Add(20); + histogram->GetCell("good")->Add(200); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + string expected_result = absl::StrJoin( + {"# TYPE 
:test:path:histogram histogram", + ":test:path:histogram_bucket{status=\"good\",le=\"1\"} 0", + ":test:path:histogram_bucket{status=\"good\",le=\"2\"} 0", + ":test:path:histogram_bucket{status=\"good\",le=\"4\"} 1", + ":test:path:histogram_bucket{status=\"good\",le=\"8\"} 1", + ":test:path:histogram_bucket{status=\"good\",le=\"16\"} 1", + ":test:path:histogram_bucket{status=\"good\",le=\"32\"} 2", + ":test:path:histogram_bucket{status=\"good\",le=\"64\"} 2", + ":test:path:histogram_bucket{status=\"good\",le=\"128\"} 2", + ":test:path:histogram_bucket{status=\"good\",le=\"256\"} 3", + ":test:path:histogram_bucket{status=\"good\",le=\"512\"} 3", + ":test:path:histogram_bucket{status=\"good\",le=\"+Inf\"} 3", + ":test:path:histogram_sum{status=\"good\"} 222", + ":test:path:histogram_count{status=\"good\"} 3"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + EXPECT_EQ(http_page, expected_result); +} + +TEST(PrometheusExporterTest, SanitizeLabelValue) { + auto exporter = absl::make_unique(); + auto counter = absl::WrapUnique( + monitoring::Counter<1>::New("/test/path/total", "A counter.", "name")); + // label value: "abc\" + counter->GetCell("\"abc\\\"")->IncrementBy(2); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + + string expected_result = + absl::StrJoin({"# TYPE :test:path:total counter", + ":test:path:total{name=\"\\\"abc\\\\\\\"\"} 2"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + EXPECT_EQ(http_page, expected_result); +} + +TEST(PrometheusExporterTest, SanitizeLabelName) { + auto exporter = absl::make_unique(); + auto counter = absl::WrapUnique(monitoring::Counter<1>::New( + "/test/path/total", "A counter.", "my-name+1")); + counter->GetCell("abc")->IncrementBy(2); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + + string expected_result = + absl::StrJoin({"# TYPE :test:path:total counter", + ":test:path:total{my_name_1=\"abc\"} 2"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + 
EXPECT_EQ(http_page, expected_result); +} + +TEST(PrometheusExporterTest, SanitizeMetricName) { + auto exporter = absl::make_unique(); + auto counter = absl::WrapUnique( + monitoring::Counter<1>::New("0/path-total_count", "A counter.", "name")); + counter->GetCell("abc")->IncrementBy(2); + + string http_page; + Status status = exporter->GeneratePage(&http_page); + + string expected_result = + absl::StrJoin({"# TYPE _0:path:total_count counter", + "_0:path:total_count{name=\"abc\"} 2"}, + "\n"); + absl::StrAppend(&expected_result, "\n"); + EXPECT_EQ(http_page, expected_result); +} + +} // namespace +} // namespace serving +} // namespace tensorflow