From e56a8b0ccf4687030d0e2210f2849e4894568130 Mon Sep 17 00:00:00 2001 From: Tijmen van Nesselrooij Date: Wed, 3 Sep 2025 11:28:32 +0200 Subject: [PATCH] Replace hack with Arc --- src/electricity_api/src/main.rs | 200 +++++++++++++++++--------------- 1 file changed, 105 insertions(+), 95 deletions(-) diff --git a/src/electricity_api/src/main.rs b/src/electricity_api/src/main.rs index cc5bbd9..66dc325 100644 --- a/src/electricity_api/src/main.rs +++ b/src/electricity_api/src/main.rs @@ -8,18 +8,12 @@ mod helpers; #[tokio::main] async fn main() { - let args = CommandLineArgs::parse(); + let args = std::sync::Arc::new(CommandLineArgs::parse()); colog::init(); log::info!("Electricity API is starting up"); - serve( - leak_string_to_static(&args.database_path), - args.listening_port, - leak_string_to_static(&args.http_header_name_to_validate), - leak_string_to_static(&args.http_header_value_to_validate), - ) - .await; + serve(&args).await; } #[derive(clap::Parser)] @@ -37,110 +31,126 @@ struct CommandLineArgs { http_header_value_to_validate: String, } -// This is really stupid to do but the alternative seems to be that every -// endpoint needs its own complete copy of all the configuration values? Weird. -fn leak_string_to_static(value: &String) -> &'static str { - Box::leak(value.clone().into_boxed_str()) -} - -async fn serve( - database_path: &'static str, - listening_port: u16, - http_header_name_to_validate: &'static str, - http_header_value_to_validate: &'static str, -) { +async fn serve(configuration: &std::sync::Arc) { let day = warp::get() .and(warp::path("day")) .and(warp::query::>()) - .and(warp::header::(http_header_name_to_validate)) - .map(|query: HashMap, header_value: String| { - if !is_valid_header(&header_value, http_header_value_to_validate) { - log::info!( - "Access requested to /day with invalid header value {}", - header_value - ); - return Response::builder() - .status(403) - .body(String::from("Forbidden")); - } - - match helpers::try_parse_query_date(query.get("date")) { - Some(date) => { - let json = get_day_power_json(&date, database_path); + .and(warp::header::headers_cloned()) + .and(warp::any().map({ + let configuration = configuration.clone(); + move || configuration.clone() + })) + .map( + |query: HashMap, + headers, + configuration: std::sync::Arc| { + if !has_required_header( + &headers, + &configuration.http_header_name_to_validate, + &configuration.http_header_value_to_validate, + ) { + log::info!("Access requested to /day with invalid header value"); return Response::builder() - .header("Content-Type", "application/json") - .body(json); + .status(403) + .body(String::from("Forbidden")); } - _ => Response::builder() - .status(400) - .body(String::from("Unsupported \"date\" param in query.")), - } - }); + + match helpers::try_parse_query_date(query.get("date")) { + Some(date) => { + let json = get_day_power_json(&date, &configuration.database_path); + return Response::builder() + .header("Content-Type", "application/json") + .body(json); + } + _ => Response::builder() + .status(400) + .body(String::from("Unsupported \"date\" param in query.")), + } + }, + ); let days = warp::get() .and(warp::path("days")) .and(warp::query::>()) - .and(warp::header::(http_header_name_to_validate)) - .map(|query: HashMap, header_value: String| { - if !is_valid_header(&header_value, http_header_value_to_validate) { - log::info!( - "Access requested to /days with invalid header value {}", - header_value - ); + .and(warp::header::headers_cloned()) + .and(warp::any().map({ + let configuration = configuration.clone(); + move || configuration.clone() + })) + .map( + |query: HashMap, + headers, + configuration: std::sync::Arc| { + if !has_required_header( + &headers, + &configuration.http_header_name_to_validate, + &configuration.http_header_value_to_validate, + ) { + log::info!("Access requested to /days with invalid header value"); + return Response::builder() + .status(403) + .body(String::from("Forbidden")); + } + + let maybe_start = helpers::try_parse_query_date(query.get("start")); + if maybe_start.is_none() { + return Response::builder() + .status(400) + .body(String::from("Unsupported \"start\" param in query.")); + } + + let maybe_stop = helpers::try_parse_query_date(query.get("stop")); + if maybe_stop.is_none() { + return Response::builder() + .status(400) + .body(String::from("Unsupported \"stop\" param in query.")); + } + + let start = maybe_start.unwrap(); + let stop = maybe_stop.unwrap(); + if start > stop { + return Response::builder().status(400).body(String::from( + "Param \"start\" must be smaller than or equal to param \"stop\" in query.", + )); + } + + let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1)); + if maybe_day_before_start.is_none() { + return Response::builder() + .status(400) + .body(String::from("Param \"start\" in query is too early.")); + } + + let day_before_start = maybe_day_before_start.unwrap(); + let json = + get_days_power_json(&day_before_start, &stop, &configuration.database_path); return Response::builder() - .status(403) - .body(String::from("Forbidden")); - } - - let maybe_start = helpers::try_parse_query_date(query.get("start")); - if maybe_start.is_none() { - return Response::builder() - .status(400) - .body(String::from("Unsupported \"start\" param in query.")); - } - - let maybe_stop = helpers::try_parse_query_date(query.get("stop")); - if maybe_stop.is_none() { - return Response::builder() - .status(400) - .body(String::from("Unsupported \"stop\" param in query.")); - } - - let start = maybe_start.unwrap(); - let stop = maybe_stop.unwrap(); - if start > stop { - return Response::builder().status(400).body(String::from( - "Param \"start\" must be smaller than or equal to param \"stop\" in query.", - )); - } - - let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1)); - if maybe_day_before_start.is_none() { - return Response::builder() - .status(400) - .body(String::from("Param \"start\" in query is too early.")); - } - - let day_before_start = maybe_day_before_start.unwrap(); - let json = get_days_power_json(&day_before_start, &stop, database_path); - return Response::builder() - .header("Content-Type", "application/json") - .body(json); - }); + .header("Content-Type", "application/json") + .body(json); + }, + ); warp::serve(day.or(days)) - .run(([127, 0, 0, 1], listening_port)) + .run(([127, 0, 0, 1], configuration.listening_port)) .await; } -fn is_valid_header(header_value: &str, allowed_values: &str) -> bool { - for value in allowed_values.split(',') { - if *header_value == *value { - return true; +fn has_required_header( + headers: &warp::http::HeaderMap, + expected_header_name: &str, + allowed_header_values: &str, +) -> bool { + match headers.iter().find(|h| *h.0 == *expected_header_name) { + Some((_, header_value)) => { + for value in allowed_header_values.split(',') { + if *header_value == *value { + return true; + } + } + return false; } + _ => return false, } - - false } fn get_day_power_json(date: &chrono::NaiveDate, database_path: &str) -> String {