Skip to content

Commit

Permalink
http(serve): move state outside of serve function, use Extension for …
Browse files Browse the repository at this point in the history
…examples due to sqlx compile time checking
  • Loading branch information
Nick Miller committed Aug 28, 2023
1 parent ebaddea commit 8479839
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ name = "demo"
name = "echo"

[dependencies]
axum = { version = "0.6.20", features = ["json"] }
axum = { version = "0.6.20", features = ["json", "macros"] }
clap = { version = "4", features = ["derive", "env"] }
once_cell = "1.18"
prometheus = "0.13"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.32", features = ["full"] }
tower = "0.4"
tower-http = { version = "0.4", features = ["cors", "trace", "map-request-body"] }
tower-http = { version = "0.4", features = ["cors", "trace", "map-request-body", "util"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "std", "json", "env-filter"] }

Expand Down
13 changes: 7 additions & 6 deletions examples/demo/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-or-later

use servus::axum::{
extract::{self, State},
extract::{self, Extension},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Expand All @@ -25,6 +25,7 @@ struct AppConfig {
response: String,
}

#[derive(Clone)]
struct AppState {
pool: sqlx::postgres::PgPool,
}
Expand Down Expand Up @@ -56,13 +57,13 @@ async fn main() -> anyhow::Result<()> {

let router = Router::new()
.route("/message", post(post_message))
.route("/message/all", get(get_messages));
.route("/message/all", get(get_messages))
.layer(Extension(state));

servus::http::serve(
config.servus.http_address,
Some(config.servus.metrics_address),
router,
state,
)
.await;

Expand All @@ -76,7 +77,7 @@ struct Message {
}

async fn post_message(
State(state): State<Arc<AppState>>,
Extension(state): Extension<Arc<AppState>>,
extract::Json(payload): extract::Json<Message>,
) -> StatusCode {
info!(
Expand All @@ -88,7 +89,7 @@ async fn post_message(
if let Err(e) = sqlx::query!(
"INSERT INTO guestbook (author, message) VALUES ($1, $2)",
payload.author,
payload.message
payload.message,
)
.execute(&state.pool)
.await
Expand All @@ -100,7 +101,7 @@ async fn post_message(
StatusCode::OK
}

async fn get_messages(State(state): State<Arc<AppState>>) -> impl IntoResponse {
async fn get_messages(Extension(state): Extension<Arc<AppState>>) -> impl IntoResponse {
info!(message = "got get messages request!");

let q = sqlx::query!("select * from guestbook")
Expand Down
2 changes: 1 addition & 1 deletion examples/echo/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn main() -> anyhow::Result<()> {
// Note, we pass the `metrics_address` parameter value as `None` to imply we don't want to
// start the metrics server. Also, the `state` parameter is the unit type `()`, meaning we have
// no global state and all handlers are stateless.
servus::http::serve(config.servus.http_address, None, router, ()).await;
servus::http::serve(config.servus.http_address, None, router).await;

Ok(())
}
Expand Down
18 changes: 7 additions & 11 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,14 @@ use tracing::{error, Level};
///
/// Both the application server and metrics server will respond to a CTRL-C shutdown signal and
/// terminate gracefully.
pub async fn serve<S>(
pub async fn serve(
http_address: SocketAddr,
metrics_address: Option<SocketAddr>,
router: Router<S>,
state: S,
) where
S: Send + Sync + Clone + 'static,
{
router: Router<()>,
) {
// create primary application router and server
// applying handler state if we have it, and default metrics/tracing middleware
let r = router
.with_state(state)
let router = router
.route_layer(middleware::from_fn(metrics::middleware)) // only record matched routes
.layer(
TraceLayer::new_for_http().make_span_with(
Expand All @@ -54,17 +50,17 @@ pub async fn serve<S>(
);

let app = Server::bind(&http_address)
.serve(r.into_make_service())
.serve(router.into_make_service())
.with_graceful_shutdown(shutdown_signal());

if let Some(metrics_address) = metrics_address {
// create metrics router and server, also used for healthcheck
let r = Router::new()
let router = Router::new()
.route("/metrics", routing::get(metrics::handler))
.route("/health", routing::get(health));

let metrics = Server::bind(&metrics_address)
.serve(r.into_make_service())
.serve(router.into_make_service())
.with_graceful_shutdown(shutdown_signal());

// spawn each server instance (so they can be scheduled on separate threads as necessary)
Expand Down

0 comments on commit 8479839

Please sign in to comment.