diff --git a/sdk/include/opentelemetry/sdk/metrics/meter.h b/sdk/include/opentelemetry/sdk/metrics/meter.h index ef0d5ffb4e..b2c86a7e49 100644 --- a/sdk/include/opentelemetry/sdk/metrics/meter.h +++ b/sdk/include/opentelemetry/sdk/metrics/meter.h @@ -120,6 +120,7 @@ class Meter final : public opentelemetry::metrics::Meter InstrumentDescriptor &instrument_descriptor); std::unique_ptr RegisterAsyncMetricStorage( InstrumentDescriptor &instrument_descriptor); + opentelemetry::common::SpinLockMutex storage_lock_; }; } // namespace metrics } // namespace sdk diff --git a/sdk/src/metrics/meter.cc b/sdk/src/metrics/meter.cc index 27b8044cdb..8d36509c34 100644 --- a/sdk/src/metrics/meter.cc +++ b/sdk/src/metrics/meter.cc @@ -208,6 +208,7 @@ const sdk::instrumentationscope::InstrumentationScope *Meter::GetInstrumentation std::unique_ptr Meter::RegisterSyncMetricStorage( InstrumentDescriptor &instrument_descriptor) { + std::lock_guard guard(storage_lock_); auto ctx = meter_context_.lock(); if (!ctx) { @@ -251,6 +252,7 @@ std::unique_ptr Meter::RegisterSyncMetricStorage( std::unique_ptr Meter::RegisterAsyncMetricStorage( InstrumentDescriptor &instrument_descriptor) { + std::lock_guard guard(storage_lock_); auto ctx = meter_context_.lock(); if (!ctx) { @@ -302,6 +304,7 @@ std::vector Meter::Collect(CollectorHandle *collector, << "The metric context is invalid"); return std::vector{}; } + std::lock_guard guard(storage_lock_); for (auto &metric_storage : storage_registry_) { metric_storage.second->Collect(collector, ctx->GetCollectors(), ctx->GetSDKStartTime(), diff --git a/sdk/test/metrics/meter_test.cc b/sdk/test/metrics/meter_test.cc index 9108297624..77fde74d03 100644 --- a/sdk/test/metrics/meter_test.cc +++ b/sdk/test/metrics/meter_test.cc @@ -30,14 +30,15 @@ class MockMetricReader : public MetricReader namespace { -nostd::shared_ptr InitMeter(MetricReader **metricReaderPtr) +nostd::shared_ptr InitMeter(MetricReader **metricReaderPtr, + std::string meter_name = "meter_name") { static std::shared_ptr provider(new MeterProvider()); std::unique_ptr metric_reader(new MockMetricReader()); *metricReaderPtr = metric_reader.get(); auto p = std::static_pointer_cast(provider); p->AddMetricReader(std::move(metric_reader)); - auto meter = provider->GetMeter("meter_name"); + auto meter = provider->GetMeter(meter_name); return meter; } } // namespace @@ -70,6 +71,65 @@ TEST(MeterTest, BasicAsyncTests) } return true; }); + observable_counter->RemoveCallback(asyc_generate_measurements, nullptr); +} + +constexpr static unsigned MAX_THREADS = 25; +constexpr static unsigned MAX_ITERATIONS_MT = 1000; + +TEST(MeterTest, StressMultiThread) +{ + MetricReader *metric_reader_ptr = nullptr; + auto meter = InitMeter(&metric_reader_ptr, "stress_test_meter"); + std::atomic threadCount(0); + size_t numIterations = MAX_ITERATIONS_MT; + std::atomic do_collect{false}, do_sync_create{true}, do_async_create{false}; + std::vector> + observable_instruments; + std::vector meter_operation_threads; + size_t instrument_id = 0; + while (numIterations--) + { + for (size_t i = 0; i < MAX_THREADS; i++) + { + if (threadCount++ < MAX_THREADS) + { + auto t = std::thread([&]() { + std::this_thread::yield(); + if (do_sync_create.exchange(false)) + { + std::string instrument_name = "test_couter_" + std::to_string(instrument_id); + meter->CreateLongCounter(instrument_name, "", ""); + do_async_create.store(true); + instrument_id++; + } + if (do_async_create.exchange(false)) + { + std::cout << "\n creating async thread " << std::to_string(numIterations); + auto observable_instrument = + meter->CreateLongObservableGauge("test_gauge_" + std::to_string(instrument_id)); + observable_instrument->AddCallback(asyc_generate_measurements, nullptr); + observable_instruments.push_back(std::move(observable_instrument)); + do_collect.store(true); + instrument_id++; + } + if (do_collect.exchange(false)) + { + metric_reader_ptr->Collect([](ResourceMetrics &metric_data) { return true; }); + do_sync_create.store(true); + } + }); + meter_operation_threads.push_back(std::move(t)); + } + } + } + for (auto &t : meter_operation_threads) + { + if (t.joinable()) + { + t.join(); + } + } } #endif