diff --git a/shotover/src/transforms/query_counter.rs b/shotover/src/transforms/query_counter.rs index f0d10bf7c..760f6c3c1 100644 --- a/shotover/src/transforms/query_counter.rs +++ b/shotover/src/transforms/query_counter.rs @@ -6,8 +6,10 @@ use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; use metrics::counter; +use metrics::Counter; use serde::Deserialize; use serde::Serialize; +use std::collections::HashMap; use super::DownChainProtocol; use super::TransformContextConfig; @@ -16,6 +18,7 @@ use super::UpChainProtocol; #[derive(Clone)] pub struct QueryCounter { counter_name: &'static str, + query_to_counter: HashMap, } #[derive(Serialize, Deserialize, Debug)] @@ -26,13 +29,23 @@ pub struct QueryCounterConfig { impl QueryCounter { pub fn new(counter_name: String) -> Self { - counter!("shotover_query_count", "name" => counter_name.clone()); + // Leaking here is fine since the builder is created only once during shotover startup. + let counter_name_ref: &'static str = counter_name.leak(); + + // Although not incremented, this counter needs to be created to ensure shotover_query_count is 0 on shotover startup. + counter!("shotover_query_count", "name" => counter_name_ref); QueryCounter { - // Leaking here is fine since the builder is created only once during shotover startup. - counter_name: counter_name.leak(), + counter_name: counter_name_ref, + query_to_counter: HashMap::new(), } } + + fn increment_counter(&mut self, query: String, query_type: &'static str) { + self.query_to_counter.entry(query) + .or_insert_with_key(|query| counter!("shotover_query_count", "name" => self.counter_name, "query" => query.clone(), "type" => query_type)) + .increment(1); + } } impl TransformBuilder for QueryCounter { @@ -57,20 +70,20 @@ impl Transform for QueryCounter { #[cfg(feature = "cassandra")] Some(Frame::Cassandra(frame)) => { for statement in frame.operation.queries() { - counter!("shotover_query_count", "name" => self.counter_name, "query" => statement.short_name(), "type" => "cassandra").increment(1); + self.increment_counter(statement.short_name().to_string(), "cassandra"); } } #[cfg(feature = "redis")] Some(Frame::Redis(frame)) => { if let Some(query_type) = crate::frame::redis::redis_query_name(frame) { - counter!("shotover_query_count", "name" => self.counter_name, "query" => query_type, "type" => "redis").increment(1); + self.increment_counter(query_type, "redis"); } else { - counter!("shotover_query_count", "name" => self.counter_name, "query" => "unknown", "type" => "redis").increment(1); + self.increment_counter("unknown".to_string(), "redis"); } } #[cfg(feature = "kafka")] Some(Frame::Kafka(_)) => { - counter!("shotover_query_count", "name" => self.counter_name, "query" => "unknown", "type" => "kafka").increment(1); + self.increment_counter("unknown".to_string(), "kafka"); } Some(Frame::Dummy) => { // Dummy does not count as a message @@ -80,7 +93,7 @@ impl Transform for QueryCounter { todo!(); } None => { - counter!("shotover_query_count", "name" => self.counter_name, "query" => "unknown", "type" => "none").increment(1) + self.increment_counter("unknown".to_string(), "none"); } } }