Replace hack with Arc

This commit is contained in:
2025-09-03 11:28:32 +02:00
parent 90544b5b07
commit e56a8b0ccf

View File

@@ -8,18 +8,12 @@ mod helpers;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let args = CommandLineArgs::parse(); let args = std::sync::Arc::new(CommandLineArgs::parse());
colog::init(); colog::init();
log::info!("Electricity API is starting up"); log::info!("Electricity API is starting up");
serve( serve(&args).await;
leak_string_to_static(&args.database_path),
args.listening_port,
leak_string_to_static(&args.http_header_name_to_validate),
leak_string_to_static(&args.http_header_value_to_validate),
)
.await;
} }
#[derive(clap::Parser)] #[derive(clap::Parser)]
@@ -37,110 +31,126 @@ struct CommandLineArgs {
http_header_value_to_validate: String, http_header_value_to_validate: String,
} }
// This is really stupid to do but the alternative seems to be that every async fn serve(configuration: &std::sync::Arc<CommandLineArgs>) {
// endpoint needs its own complete copy of all the configuration values? Weird.
fn leak_string_to_static(value: &String) -> &'static str {
Box::leak(value.clone().into_boxed_str())
}
async fn serve(
database_path: &'static str,
listening_port: u16,
http_header_name_to_validate: &'static str,
http_header_value_to_validate: &'static str,
) {
let day = warp::get() let day = warp::get()
.and(warp::path("day")) .and(warp::path("day"))
.and(warp::query::<HashMap<String, String>>()) .and(warp::query::<HashMap<String, String>>())
.and(warp::header::<String>(http_header_name_to_validate)) .and(warp::header::headers_cloned())
.map(|query: HashMap<String, String>, header_value: String| { .and(warp::any().map({
if !is_valid_header(&header_value, http_header_value_to_validate) { let configuration = configuration.clone();
log::info!( move || configuration.clone()
"Access requested to /day with invalid header value {}", }))
header_value .map(
); |query: HashMap<String, String>,
return Response::builder() headers,
.status(403) configuration: std::sync::Arc<CommandLineArgs>| {
.body(String::from("Forbidden")); if !has_required_header(
} &headers,
&configuration.http_header_name_to_validate,
match helpers::try_parse_query_date(query.get("date")) { &configuration.http_header_value_to_validate,
Some(date) => { ) {
let json = get_day_power_json(&date, database_path); log::info!("Access requested to /day with invalid header value");
return Response::builder() return Response::builder()
.header("Content-Type", "application/json") .status(403)
.body(json); .body(String::from("Forbidden"));
} }
_ => Response::builder()
.status(400) match helpers::try_parse_query_date(query.get("date")) {
.body(String::from("Unsupported \"date\" param in query.")), Some(date) => {
} let json = get_day_power_json(&date, &configuration.database_path);
}); return Response::builder()
.header("Content-Type", "application/json")
.body(json);
}
_ => Response::builder()
.status(400)
.body(String::from("Unsupported \"date\" param in query.")),
}
},
);
let days = warp::get() let days = warp::get()
.and(warp::path("days")) .and(warp::path("days"))
.and(warp::query::<HashMap<String, String>>()) .and(warp::query::<HashMap<String, String>>())
.and(warp::header::<String>(http_header_name_to_validate)) .and(warp::header::headers_cloned())
.map(|query: HashMap<String, String>, header_value: String| { .and(warp::any().map({
if !is_valid_header(&header_value, http_header_value_to_validate) { let configuration = configuration.clone();
log::info!( move || configuration.clone()
"Access requested to /days with invalid header value {}", }))
header_value .map(
); |query: HashMap<String, String>,
headers,
configuration: std::sync::Arc<CommandLineArgs>| {
if !has_required_header(
&headers,
&configuration.http_header_name_to_validate,
&configuration.http_header_value_to_validate,
) {
log::info!("Access requested to /days with invalid header value");
return Response::builder()
.status(403)
.body(String::from("Forbidden"));
}
let maybe_start = helpers::try_parse_query_date(query.get("start"));
if maybe_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"start\" param in query."));
}
let maybe_stop = helpers::try_parse_query_date(query.get("stop"));
if maybe_stop.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"stop\" param in query."));
}
let start = maybe_start.unwrap();
let stop = maybe_stop.unwrap();
if start > stop {
return Response::builder().status(400).body(String::from(
"Param \"start\" must be smaller than or equal to param \"stop\" in query.",
));
}
let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1));
if maybe_day_before_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Param \"start\" in query is too early."));
}
let day_before_start = maybe_day_before_start.unwrap();
let json =
get_days_power_json(&day_before_start, &stop, &configuration.database_path);
return Response::builder() return Response::builder()
.status(403) .header("Content-Type", "application/json")
.body(String::from("Forbidden")); .body(json);
} },
);
let maybe_start = helpers::try_parse_query_date(query.get("start"));
if maybe_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"start\" param in query."));
}
let maybe_stop = helpers::try_parse_query_date(query.get("stop"));
if maybe_stop.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"stop\" param in query."));
}
let start = maybe_start.unwrap();
let stop = maybe_stop.unwrap();
if start > stop {
return Response::builder().status(400).body(String::from(
"Param \"start\" must be smaller than or equal to param \"stop\" in query.",
));
}
let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1));
if maybe_day_before_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Param \"start\" in query is too early."));
}
let day_before_start = maybe_day_before_start.unwrap();
let json = get_days_power_json(&day_before_start, &stop, database_path);
return Response::builder()
.header("Content-Type", "application/json")
.body(json);
});
warp::serve(day.or(days)) warp::serve(day.or(days))
.run(([127, 0, 0, 1], listening_port)) .run(([127, 0, 0, 1], configuration.listening_port))
.await; .await;
} }
fn is_valid_header(header_value: &str, allowed_values: &str) -> bool { fn has_required_header(
for value in allowed_values.split(',') { headers: &warp::http::HeaderMap,
if *header_value == *value { expected_header_name: &str,
return true; allowed_header_values: &str,
) -> bool {
match headers.iter().find(|h| *h.0 == *expected_header_name) {
Some((_, header_value)) => {
for value in allowed_header_values.split(',') {
if *header_value == *value {
return true;
}
}
return false;
} }
_ => return false,
} }
false
} }
fn get_day_power_json(date: &chrono::NaiveDate, database_path: &str) -> String { fn get_day_power_json(date: &chrono::NaiveDate, database_path: &str) -> String {