diff --git a/kitchen/src/web/metrics.rs b/kitchen/src/web/metrics.rs index 52f6c53..47bfd4d 100644 --- a/kitchen/src/web/metrics.rs +++ b/kitchen/src/web/metrics.rs @@ -11,78 +11,99 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, Mutex, +//! A [metrics] powered [TraceLayer] that works with any [Tower](https://crates.io/crates/tower) middleware. +use axum::http::{Request, Response}; +use metrics::{histogram, increment_counter, Label}; +use std::{ + marker::PhantomData, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, }; - -use axum::{body::Bytes, http::Request, http::Response}; -use metrics::{Key, Label, Recorder}; -use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusRecorder}; use tower_http::{ classify::{ServerErrorsAsFailures, SharedClassifier}, trace::{ - DefaultMakeSpan, DefaultOnEos, OnBodyChunk, OnEos, OnFailure, OnRequest, OnResponse, - TraceLayer, + DefaultMakeSpan, DefaultOnEos, OnBodyChunk, OnFailure, OnRequest, OnResponse, TraceLayer, }, }; +use tracing; -// We want to track requeste count, request latency, request size minimally. - -pub type MetricsTraceLayer = TraceLayer< +/// A Metrics Trace Layer using a [MetricsRecorder]. +/// +/// The layer will record 4 different metrics: +/// +/// * http_request_counter +/// * http_request_failure_counter +/// * http_request_size_bytes_hist +/// * http_request_request_time_micros_hist +/// +/// Each of the metrics are labled by host, method, and path +pub type MetricsTraceLayer = TraceLayer< SharedClassifier, DefaultMakeSpan, - MetricsRecorder, - MetricsRecorder, - MetricsRecorder, + MetricsRecorder, + MetricsRecorder, + MetricsRecorder, DefaultOnEos, - MetricsRecorder, + MetricsRecorder, >; -pub fn get_recorder() -> PrometheusRecorder { - let builder = PrometheusBuilder::new(); - builder.build_recorder() -} - -#[derive(Clone)] -pub struct MetricsRecorder { - rec: Arc, +/// Holds the state required for recording metrics on a given request. +pub struct MetricsRecorder +where + F: Fn(&B) -> u64, +{ labels: Arc>>, size: Arc, + chunk_len: Arc, + _phantom: PhantomData, } -impl MetricsRecorder { - pub fn new_with_rec(rec: Arc) -> Self { +impl Clone for MetricsRecorder +where + F: Fn(&B) -> u64, +{ + fn clone(&self) -> Self { Self { - rec, - labels: Arc::new(Mutex::new(Vec::new())), - size: Arc::new(AtomicU64::new(0)), + labels: self.labels.clone(), + size: self.size.clone(), + chunk_len: self.chunk_len.clone(), + _phantom: self._phantom.clone(), } } } -impl OnBodyChunk for MetricsRecorder { - fn on_body_chunk( - &mut self, - chunk: &Bytes, - _latency: std::time::Duration, - _span: &tracing::Span, - ) { - let _ = self.size.fetch_add(chunk.len() as u64, Ordering::SeqCst); +impl MetricsRecorder +where + F: Fn(&B) -> u64, +{ + /// Construct a new [MetricsRecorder] using the installed [Recorder]. + pub fn new(f: F) -> Self { + Self { + labels: Arc::new(Mutex::new(Vec::new())), + size: Arc::new(AtomicU64::new(0)), + chunk_len: Arc::new(f), + _phantom: PhantomData, + } } } -impl OnEos for MetricsRecorder { - fn on_eos( - self, - _trailers: Option<&axum::http::HeaderMap>, - _stream_duration: std::time::Duration, - _span: &tracing::Span, - ) { +impl OnBodyChunk for MetricsRecorder +where + F: Fn(&B) -> u64, +{ + fn on_body_chunk(&mut self, chunk: &B, _latency: std::time::Duration, _span: &tracing::Span) { + let _ = self + .size + .fetch_add(self.chunk_len.as_ref()(chunk), Ordering::SeqCst); } } -impl OnFailure for MetricsRecorder { +impl OnFailure for MetricsRecorder +where + F: Fn(&B) -> u64, +{ fn on_failure( &mut self, _failure_classification: FailureClass, @@ -90,30 +111,31 @@ impl OnFailure for MetricsRecorder { _span: &tracing::Span, ) { let labels = self.labels.lock().expect("Failed to unlock labels").clone(); - self.rec - .as_ref() - .register_histogram(&Key::from_parts("http_request_failure_counter", labels)); + increment_counter!("http_request_failure_counter", labels); } } -impl OnResponse for MetricsRecorder { +impl OnResponse for MetricsRecorder +where + F: Fn(&B) -> u64, +{ fn on_response( self, - _response: &Response, + _response: &Response, latency: std::time::Duration, _span: &tracing::Span, ) { let labels = self.labels.lock().expect("Failed to unlock labels").clone(); - self.rec - .as_ref() - .register_histogram(&Key::from_parts("http_request_time_micros", labels.clone())) - // If we somehow end up having requests overflow from u128 into f64 then we have - // much bigger problems than this cast. - .record(latency.as_micros() as f64); - self.rec - .as_ref() - .register_histogram(&Key::from_parts("http_request_size_bytes", labels)) - .record(self.size.as_ref().load(Ordering::SeqCst) as f64); + histogram!( + "http_request_time_micros_hist", + latency.as_micros() as f64, + labels.clone() + ); + histogram!( + "http_request_size_bytes_hist", + self.size.as_ref().load(Ordering::SeqCst) as f64, + labels + ) } } @@ -125,9 +147,11 @@ fn make_request_lables(path: String, host: String, method: String) -> Vec