Replace hack with Arc

This commit is contained in:
2025-09-03 11:28:32 +02:00
parent 90544b5b07
commit e56a8b0ccf

View File

@@ -8,18 +8,12 @@ mod helpers;
#[tokio::main]
async fn main() {
let args = CommandLineArgs::parse();
let args = std::sync::Arc::new(CommandLineArgs::parse());
colog::init();
log::info!("Electricity API is starting up");
serve(
leak_string_to_static(&args.database_path),
args.listening_port,
leak_string_to_static(&args.http_header_name_to_validate),
leak_string_to_static(&args.http_header_value_to_validate),
)
.await;
serve(&args).await;
}
#[derive(clap::Parser)]
@@ -37,110 +31,126 @@ struct CommandLineArgs {
http_header_value_to_validate: String,
}
// This is really stupid to do but the alternative seems to be that every
// endpoint needs its own complete copy of all the configuration values? Weird.
fn leak_string_to_static(value: &String) -> &'static str {
Box::leak(value.clone().into_boxed_str())
}
async fn serve(
database_path: &'static str,
listening_port: u16,
http_header_name_to_validate: &'static str,
http_header_value_to_validate: &'static str,
) {
async fn serve(configuration: &std::sync::Arc<CommandLineArgs>) {
let day = warp::get()
.and(warp::path("day"))
.and(warp::query::<HashMap<String, String>>())
.and(warp::header::<String>(http_header_name_to_validate))
.map(|query: HashMap<String, String>, header_value: String| {
if !is_valid_header(&header_value, http_header_value_to_validate) {
log::info!(
"Access requested to /day with invalid header value {}",
header_value
);
return Response::builder()
.status(403)
.body(String::from("Forbidden"));
}
match helpers::try_parse_query_date(query.get("date")) {
Some(date) => {
let json = get_day_power_json(&date, database_path);
.and(warp::header::headers_cloned())
.and(warp::any().map({
let configuration = configuration.clone();
move || configuration.clone()
}))
.map(
|query: HashMap<String, String>,
headers,
configuration: std::sync::Arc<CommandLineArgs>| {
if !has_required_header(
&headers,
&configuration.http_header_name_to_validate,
&configuration.http_header_value_to_validate,
) {
log::info!("Access requested to /day with invalid header value");
return Response::builder()
.header("Content-Type", "application/json")
.body(json);
.status(403)
.body(String::from("Forbidden"));
}
_ => Response::builder()
.status(400)
.body(String::from("Unsupported \"date\" param in query.")),
}
});
match helpers::try_parse_query_date(query.get("date")) {
Some(date) => {
let json = get_day_power_json(&date, &configuration.database_path);
return Response::builder()
.header("Content-Type", "application/json")
.body(json);
}
_ => Response::builder()
.status(400)
.body(String::from("Unsupported \"date\" param in query.")),
}
},
);
let days = warp::get()
.and(warp::path("days"))
.and(warp::query::<HashMap<String, String>>())
.and(warp::header::<String>(http_header_name_to_validate))
.map(|query: HashMap<String, String>, header_value: String| {
if !is_valid_header(&header_value, http_header_value_to_validate) {
log::info!(
"Access requested to /days with invalid header value {}",
header_value
);
.and(warp::header::headers_cloned())
.and(warp::any().map({
let configuration = configuration.clone();
move || configuration.clone()
}))
.map(
|query: HashMap<String, String>,
headers,
configuration: std::sync::Arc<CommandLineArgs>| {
if !has_required_header(
&headers,
&configuration.http_header_name_to_validate,
&configuration.http_header_value_to_validate,
) {
log::info!("Access requested to /days with invalid header value");
return Response::builder()
.status(403)
.body(String::from("Forbidden"));
}
let maybe_start = helpers::try_parse_query_date(query.get("start"));
if maybe_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"start\" param in query."));
}
let maybe_stop = helpers::try_parse_query_date(query.get("stop"));
if maybe_stop.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"stop\" param in query."));
}
let start = maybe_start.unwrap();
let stop = maybe_stop.unwrap();
if start > stop {
return Response::builder().status(400).body(String::from(
"Param \"start\" must be smaller than or equal to param \"stop\" in query.",
));
}
let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1));
if maybe_day_before_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Param \"start\" in query is too early."));
}
let day_before_start = maybe_day_before_start.unwrap();
let json =
get_days_power_json(&day_before_start, &stop, &configuration.database_path);
return Response::builder()
.status(403)
.body(String::from("Forbidden"));
}
let maybe_start = helpers::try_parse_query_date(query.get("start"));
if maybe_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"start\" param in query."));
}
let maybe_stop = helpers::try_parse_query_date(query.get("stop"));
if maybe_stop.is_none() {
return Response::builder()
.status(400)
.body(String::from("Unsupported \"stop\" param in query."));
}
let start = maybe_start.unwrap();
let stop = maybe_stop.unwrap();
if start > stop {
return Response::builder().status(400).body(String::from(
"Param \"start\" must be smaller than or equal to param \"stop\" in query.",
));
}
let maybe_day_before_start = start.checked_sub_days(chrono::Days::new(1));
if maybe_day_before_start.is_none() {
return Response::builder()
.status(400)
.body(String::from("Param \"start\" in query is too early."));
}
let day_before_start = maybe_day_before_start.unwrap();
let json = get_days_power_json(&day_before_start, &stop, database_path);
return Response::builder()
.header("Content-Type", "application/json")
.body(json);
});
.header("Content-Type", "application/json")
.body(json);
},
);
warp::serve(day.or(days))
.run(([127, 0, 0, 1], listening_port))
.run(([127, 0, 0, 1], configuration.listening_port))
.await;
}
fn is_valid_header(header_value: &str, allowed_values: &str) -> bool {
for value in allowed_values.split(',') {
if *header_value == *value {
return true;
fn has_required_header(
headers: &warp::http::HeaderMap,
expected_header_name: &str,
allowed_header_values: &str,
) -> bool {
match headers.iter().find(|h| *h.0 == *expected_header_name) {
Some((_, header_value)) => {
for value in allowed_header_values.split(',') {
if *header_value == *value {
return true;
}
}
return false;
}
_ => return false,
}
false
}
fn get_day_power_json(date: &chrono::NaiveDate, database_path: &str) -> String {