appview/router/src/xrpc.rs

191 lines
4.5 KiB
Rust
Raw Normal View History

use crate::Endpoint;
use std::{
collections::HashMap,
pin::Pin,
future::Future,
};
use atproto::types::Nsid;
use axum::{
extract::{
Json,
Query,
Request,
FromRequest,
FromRequestParts,
rejection::QueryRejection,
},
body::Bytes,
routing::{
get,
post,
method_routing::MethodRouter,
},
http::{
StatusCode,
request::Parts,
},
Router as axumRouter,
};
use serde_json::{Value, json};
/// One XRPC route: where it is mounted plus the axum method router serving it.
pub struct XrpcEndpoint {
    /// Mount point; converted to a URL path when the endpoint joins a router.
    path: Path,
    /// The axum method router that handles requests on `path`.
    resolver: MethodRouter,
}

/// Mount point of an [`XrpcEndpoint`].
enum Path {
    /// A concrete XRPC method, mounted at `/xrpc/<nsid>`.
    Nsid(Nsid),
    /// The catch-all route used for methods that have no handler.
    NotImplemented,
}
/// Reply type produced by XRPC handlers and extractor rejections in this
/// module: an HTTP status code paired with a JSON body.
pub type Response = (StatusCode, Json<Value>);

/// Builds a [`Response`] with status `code` and the JSON body
/// `{"error": error, "message": message}`.
pub fn error(code: StatusCode, error: &str, message: &str) -> Response {
    let body = json!({
        "error": error,
        "message": message,
    });
    (code, Json(body))
}

/// Builds a [`Response`] with status `code` whose JSON body carries `message`
/// and an empty `"error"` field.
pub fn response(code: StatusCode, message: &str) -> Response {
    let no_error_name = "";
    error(code, no_error_name, message)
}
/// Input handed to an XRPC query handler.
///
/// `S` is the type of the state value carried alongside the parameters; it
/// defaults to `()` for stateless handlers. Trait bounds live on the impls
/// that need them rather than on the struct, so naming `QueryInput<S>` does
/// not require repeating them.
#[derive(Debug)]
pub struct QueryInput<S = ()> {
    /// Query-string parameters decoded from the request URI, keyed by name.
    pub parameters: HashMap<String, String>,
    /// State value; its type defaults to `()`.
    pub state: S,
}
impl<S> FromRequestParts<S> for QueryInput
where
S: Send + Sync,
{
type Rejection = Response;
2025-08-29 16:54:32 -07:00
async fn from_request_parts(parts: &mut Parts, state: &S)
-> Result<Self, Self::Rejection> {
let query_params: Result<Query<HashMap<String, String>>, QueryRejection> = Query::try_from_uri(&parts.uri);
match query_params {
2025-08-29 16:54:32 -07:00
Ok(p) => Ok(QueryInput { parameters: p.0, state }),
Err(e) => Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())),
}
}
}
/// Input handed to an XRPC procedure handler.
///
/// `J` is the type the request body is decoded into. `S` is the type of the
/// state value carried alongside it and defaults to `()`. Trait bounds live on
/// the impls that need them rather than on the struct.
#[derive(Debug)]
pub struct ProcedureInput<J, S = ()> {
    /// Query-string parameters decoded from the request URI, keyed by name.
    pub parameters: HashMap<String, String>,
    /// Request body, decoded as `J`.
    pub input: J,
    /// State value; its type defaults to `()`.
    pub state: S,
}
impl<J, S> FromRequest<S> for ProcedureInput<J, S>
where
J: for<'de> serde::Deserialize<'de> + Send + 'static,
Bytes: FromRequest<S>,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request, state: &S)
-> Result<Self, Self::Rejection> {
let parameters: HashMap<String, String> = Query::try_from_uri(req.uri())
.map(|p| p.0)
.map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Paramters", &e.body_text()))?;
let input: J = Json::<J>::from_request(req, state).await
.map(|Json(v)| v)
.map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text()))?;
2025-08-29 16:54:32 -07:00
Ok(ProcedureInput { parameters, input, state })
}
}
/// An asynchronous XRPC handler that turns a decoded `Input` into a [`Response`].
///
/// `call` hands back a boxed `Send` future, so implementors never need to name
/// their concrete future type.
pub trait XrpcHandler<Input>: Send + Sync + 'static {
    /// Starts handling `input`; the returned future resolves to the reply.
    fn call(
        &self,
        input: Input,
    ) -> Pin<Box<dyn Future<Output = Response> + Send>>;
}
/// Blanket implementation: any thread-safe function or closure that maps a
/// [`QueryInput`] to a `Send` future of [`Response`] is a query handler.
impl<F, Fut> XrpcHandler<QueryInput> for F
where
    Fut: Future<Output = Response> + Send + 'static,
    F: Fn(QueryInput) -> Fut + Send + Sync + 'static,
{
    fn call(&self, input: QueryInput) -> Pin<Box<dyn Future<Output = Response> + Send>> {
        let pending = (*self)(input);
        Box::pin(pending)
    }
}
/// Blanket implementation: any thread-safe function or closure that maps a
/// [`ProcedureInput`] to a `Send` future of [`Response`] is a procedure handler.
impl<J, F, Fut> XrpcHandler<ProcedureInput<J>> for F
where
    Fut: Future<Output = Response> + Send + 'static,
    F: Fn(ProcedureInput<J>) -> Fut + Send + Sync + 'static,
{
    fn call(&self, input: ProcedureInput<J>) -> Pin<Box<dyn Future<Output = Response> + Send>> {
        let pending = (*self)(input);
        Box::pin(pending)
    }
}
impl XrpcEndpoint {
2025-08-29 16:54:32 -07:00
pub fn new_query<Q, S>(nsid: Nsid, query: Q) -> Self
where
Q: XrpcHandler<QueryInput> + Clone
{
XrpcEndpoint {
path: Path::Nsid(nsid),
2025-08-29 16:54:32 -07:00
resolver: get(async move | mut parts: Parts, state: &S | -> Response {
match QueryInput<S>::from_request_parts(&mut parts, state).await {
Ok(qi) => query.call(qi).await,
Err(e) => e
}
})
}
}
2025-08-29 16:54:32 -07:00
pub fn new_procedure<P, J, S>(nsid: Nsid, procedure: P) -> Self
where
2025-08-29 16:54:32 -07:00
P: XrpcHandler<ProcedureInput<J, S>> + Clone,
J: for<'de> serde::Deserialize<'de> + Send + 'static,
{
XrpcEndpoint {
path: Path::Nsid(nsid),
2025-08-29 16:54:32 -07:00
resolver: post(async move | req: Request, state: &S | -> Response {
match ProcedureInput::<J, S>::from_request(req, &state).await {
Ok(pi) => procedure.call(pi).await,
Err(e) => e
}
})
}
}
pub fn not_implemented() -> Self {
let resolver = (
StatusCode::NOT_IMPLEMENTED,
Json(json!({
"error": "MethodNotImplemented",
"message": "Method Not Implemented"
}))
);
XrpcEndpoint {
path: Path::NotImplemented,
resolver: get(resolver.clone()).post(resolver),
}
}
}
impl Endpoint for XrpcEndpoint {
    /// Mounts this endpoint's method router on `router` and returns the result.
    ///
    /// A named endpoint is routed at `/xrpc/<nsid>`. The not-implemented
    /// endpoint is routed at the catch-all `/xrpc/{*nsid}`.
    fn add_to_router(self, router: axumRouter) -> axumRouter {
        // Both arms produce an owned `String`. Mixing `&String` and `&'static str`
        // arms does not type-check, because match arms are not deref-coerced.
        let path = match self.path {
            Path::Nsid(nsid) => "/xrpc/".to_owned() + &nsid.to_string(),
            Path::NotImplemented => "/xrpc/{*nsid}".to_owned(),
        };
        router.route(&path, self.resolver)
    }
}