diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8304858 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.env +.pds-data/ +target/ +tmp/ + diff --git a/Cargo.lock b/Cargo.lock index 6c97f48..71ddb9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -70,8 +70,8 @@ name = "api" version = "0.1.0" dependencies = [ "atproto", - "axum", "http 1.3.1", + "router", "serde", "serde_json", "tokio", @@ -79,6 +79,18 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "async-lock" version = "3.4.0" @@ -123,6 +135,7 @@ dependencies = [ "time", "tracing", "tracing-subscriber", + "unicode-segmentation", ] [[package]] @@ -282,6 +295,15 @@ dependencies = [ "serde", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -293,9 +315,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.6.3" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced38439e7a86a4761f7f7d5ded5ff009135939ecb464a24452eaa4c1696af7d" +checksum = "f61138465baf186c63e8d9b6b613b508cd832cba4ce93cf37ce5f096f91ac1a6" dependencies = [ "bon-macros", "rustversion", @@ -303,9 +325,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.6.3" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce61d2d3844c6b8d31b2353d9f66cf5e632b3e9549583fe3cac2f4f6136725e" +checksum = "40d1dad34aa19bf02295382f08d9bc40651585bd497266831d40ee6296fb49ca" dependencies = [ "darling", "ident_case", @@ -662,6 +684,25 @@ dependencies = [ "serde", ] +[[package]] +name = "entryway" +version = "0.1.0" +dependencies = [ + "argon2", + "async-trait", + "atproto", + "http 1.3.1", + "router", + "serde", + "serde_json", + "sqlx", + "thiserror 2.0.12", + "time", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1574,6 +1615,17 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -1813,6 +1865,21 @@ dependencies = [ "zstd", ] +[[package]] +name = "router" +version = "0.1.0" +dependencies = [ + "atproto", + "axum", + "bon", + "http 1.3.1", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "rsa" version = "0.9.8" @@ -2708,6 +2775,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unsigned-varint" version = "0.8.0" diff --git a/Cargo.toml b/Cargo.toml index c029fb0..55a82e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,13 @@ [workspace] resolver = "3" -members = [ "api", "atproto","db", "ingestor"] +members = [ "api", "atproto", "entryway", "db", "router", "ingestor" ] [workspace.dependencies] async-trait = "0.1.88" atproto = { path = "./atproto" } -serde = "1.0.219" +db = { path = "./db" } +router = { path = "./router" } +serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio"] } thiserror = "2.0.12" diff --git a/api/Cargo.toml b/api/Cargo.toml index 1fc049c..5dcc6e0 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [dependencies] atproto.workspace = true -axum = { version = "0.8.3", features = ["json"] } +router.workspace = true http = "1.3.1" serde.workspace = true serde_json.workspace = true diff --git a/api/src/main.rs b/api/src/main.rs index 46e17ae..13040a6 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -1,4 +1,4 @@ -use crate::router::{ +use router::{ Router, Endpoint, xrpc::{ diff --git a/api/src/router.rs b/api/src/router.rs deleted file mode 100644 index bfa3b17..0000000 --- a/api/src/router.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::router::xrpc::{ - XrpcEndpoint, - XrpcHandler, - QueryInput, - ProcedureInput, -}; -use atproto::Nsid; -use axum::Router as AxumRouter; -use core::net::SocketAddr; -use std::net::{IpAddr, Ipv4Addr}; -use tokio::net::TcpListener; - -pub struct Router { - addr: SocketAddr, - router: AxumRouter, -} - -// In case server ever needs to support more than just XRPC -pub enum Endpoint { - Xrpc(XrpcEndpoint), -} -impl Endpoint { - pub fn new_xrpc_query(nsid: Nsid, query: Q) -> Self - where - Q: XrpcHandler + Clone - { - Endpoint::Xrpc(XrpcEndpoint::new_query(nsid,query)) - } - pub fn new_xrpc_procedure

(nsid: Nsid, procedure: P) -> Self - where - P: XrpcHandler + Clone - { - Endpoint::Xrpc(XrpcEndpoint::new_procedure(nsid,procedure)) - } -} - -pub mod xrpc; - -impl Router { - pub fn new() -> Self { - let mut router = AxumRouter::new(); - router = XrpcEndpoint::not_implemented().add_to_router(router); - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127,0,0,1)), 6702); - Router { router, addr } - } - - pub fn add_endpoint(mut self, endpoint: Endpoint) -> Self { - match endpoint { - Endpoint::Xrpc(ep) => self.router = ep.add_to_router(self.router), - }; - self - } - - pub async fn serve(self) { - let listener = TcpListener::bind(self.addr).await.unwrap(); - - axum::serve(listener, self.router).await.unwrap(); - } -} diff --git a/atproto/Cargo.toml b/atproto/Cargo.toml index 3f5445b..0881493 100644 --- a/atproto/Cargo.toml +++ b/atproto/Cargo.toml @@ -13,6 +13,7 @@ time = { version = "0.3.41", features = ["parsing", "formatting"] } tracing-subscriber.workspace = true tracing.workspace = true thiserror.workspace = true +unicode-segmentation = "1.9.0" [features] default = [] diff --git a/atproto/src/error.rs b/atproto/src/error.rs index 1db7106..e320a61 100644 --- a/atproto/src/error.rs +++ b/atproto/src/error.rs @@ -20,8 +20,12 @@ pub enum FormatError { pub enum ParseError { #[error("Time Parse Error: {0}")] Datetime(#[from] time::error::Parse), + #[error("Json Parse Error: {0}")] + Serde(#[from] serde_json::error::Error), #[error("Length of parsed object too long, max: {max:?}, got: {got:?}.")] Length { max: usize, got: usize }, + #[error("Length of parsed object too short, min: {min:?}, got: {got:?}.")] + MinLength { min: usize, got: usize }, #[error("Currently Did is enforced, cannot use handle, {handle:?}")] ForceDid { handle: String }, #[error("Incorrectly formatted")] diff --git a/atproto/src/lexicons/mod.rs b/atproto/src/lexicons/mod.rs index 8107df0..96aefa7 100644 --- a/atproto/src/lexicons/mod.rs +++ b/atproto/src/lexicons/mod.rs @@ -1,3 +1 @@ -// @generated - This file is generated by esquema-codegen (forked from atrium-codegen). DO NOT EDIT. -pub mod record; -pub mod my; +pub mod myspoor; diff --git a/atproto/src/lexicons/my.rs b/atproto/src/lexicons/my.rs deleted file mode 100644 index ac9c6bd..0000000 --- a/atproto/src/lexicons/my.rs +++ /dev/null @@ -1,3 +0,0 @@ -// @generated - This file is generated by esquema-codegen (forked from atrium-codegen). DO NOT EDIT. -//!Definitions for the `my` namespace. -pub mod spoor; diff --git a/atproto/src/lexicons/my/spoor.rs b/atproto/src/lexicons/my/spoor.rs deleted file mode 100644 index 172fa6a..0000000 --- a/atproto/src/lexicons/my/spoor.rs +++ /dev/null @@ -1,4 +0,0 @@ -// @generated - This file is generated by esquema-codegen (forked from atrium-codegen). DO NOT EDIT. -//!Definitions for the `my.spoor` namespace. -pub mod content; -pub mod log; diff --git a/atproto/src/lexicons/my/spoor/log.rs b/atproto/src/lexicons/my/spoor/log.rs deleted file mode 100644 index 87f699e..0000000 --- a/atproto/src/lexicons/my/spoor/log.rs +++ /dev/null @@ -1,16 +0,0 @@ -// @generated - This file is generated by esquema-codegen (forked from atrium-codegen). DO NOT EDIT. -//!Definitions for the `my.spoor.log` namespace. -pub mod activity; -pub mod session; -#[derive(Debug)] -pub struct Activity; -impl atrium_api::types::Collection for Activity { - const NSID: &'static str = "my.spoor.log.activity"; - type Record = activity::Record; -} -#[derive(Debug)] -pub struct Session; -impl atrium_api::types::Collection for Session { - const NSID: &'static str = "my.spoor.log.session"; - type Record = session::Record; -} diff --git a/atproto/src/lexicons/my/spoor/content.rs b/atproto/src/lexicons/myspoor/content.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/content.rs rename to atproto/src/lexicons/myspoor/content.rs diff --git a/atproto/src/lexicons/my/spoor/content/external.rs b/atproto/src/lexicons/myspoor/content/external.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/content/external.rs rename to atproto/src/lexicons/myspoor/content/external.rs diff --git a/atproto/src/lexicons/my/spoor/content/media.rs b/atproto/src/lexicons/myspoor/content/media.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/content/media.rs rename to atproto/src/lexicons/myspoor/content/media.rs diff --git a/atproto/src/lexicons/my/spoor/content/title.rs b/atproto/src/lexicons/myspoor/content/title.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/content/title.rs rename to atproto/src/lexicons/myspoor/content/title.rs diff --git a/atproto/src/lexicons/myspoor/log.rs b/atproto/src/lexicons/myspoor/log.rs new file mode 100644 index 0000000..9a008a8 --- /dev/null +++ b/atproto/src/lexicons/myspoor/log.rs @@ -0,0 +1,11 @@ +use serde::Deserialize; +use crate::types::{BoundString, StrongRef, Uri, Did, Datetime}; + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Session { + pub content: StrongRef, + pub label: Option>, + pub other_participants: Option>, + pub created_at: Datetime, +} diff --git a/atproto/src/lexicons/my/spoor/log/activity.rs b/atproto/src/lexicons/myspoor/log/activity.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/log/activity.rs rename to atproto/src/lexicons/myspoor/log/activity.rs diff --git a/atproto/src/lexicons/my/spoor/log/session.rs b/atproto/src/lexicons/myspoor/log/session.rs similarity index 100% rename from atproto/src/lexicons/my/spoor/log/session.rs rename to atproto/src/lexicons/myspoor/log/session.rs diff --git a/atproto/src/lexicons/myspoor/mod.rs b/atproto/src/lexicons/myspoor/mod.rs new file mode 100644 index 0000000..f2d6b27 --- /dev/null +++ b/atproto/src/lexicons/myspoor/mod.rs @@ -0,0 +1,2 @@ +pub mod content; +pub mod log; diff --git a/atproto/src/lexicons/record.rs b/atproto/src/lexicons/record.rs deleted file mode 100644 index 290d286..0000000 --- a/atproto/src/lexicons/record.rs +++ /dev/null @@ -1,65 +0,0 @@ -// @generated - This file is generated by esquema-codegen (forked from atrium-codegen). DO NOT EDIT. -//!A collection of known record types. -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] -#[serde(tag = "$type")] -pub enum KnownRecord { - #[serde(rename = "my.spoor.content.external")] - LexiconsMySpoorContentExternal( - Box, - ), - #[serde(rename = "my.spoor.content.media")] - LexiconsMySpoorContentMedia(Box), - #[serde(rename = "my.spoor.log.activity")] - LexiconsMySpoorLogActivity(Box), - #[serde(rename = "my.spoor.log.session")] - LexiconsMySpoorLogSession(Box), -} -impl From for KnownRecord { - fn from(record: crate::lexicons::my::spoor::content::external::Record) -> Self { - KnownRecord::LexiconsMySpoorContentExternal(Box::new(record)) - } -} -impl From for KnownRecord { - fn from( - record_data: crate::lexicons::my::spoor::content::external::RecordData, - ) -> Self { - KnownRecord::LexiconsMySpoorContentExternal(Box::new(record_data.into())) - } -} -impl From for KnownRecord { - fn from(record: crate::lexicons::my::spoor::content::media::Record) -> Self { - KnownRecord::LexiconsMySpoorContentMedia(Box::new(record)) - } -} -impl From for KnownRecord { - fn from( - record_data: crate::lexicons::my::spoor::content::media::RecordData, - ) -> Self { - KnownRecord::LexiconsMySpoorContentMedia(Box::new(record_data.into())) - } -} -impl From for KnownRecord { - fn from(record: crate::lexicons::my::spoor::log::activity::Record) -> Self { - KnownRecord::LexiconsMySpoorLogActivity(Box::new(record)) - } -} -impl From for KnownRecord { - fn from(record_data: crate::lexicons::my::spoor::log::activity::RecordData) -> Self { - KnownRecord::LexiconsMySpoorLogActivity(Box::new(record_data.into())) - } -} -impl From for KnownRecord { - fn from(record: crate::lexicons::my::spoor::log::session::Record) -> Self { - KnownRecord::LexiconsMySpoorLogSession(Box::new(record)) - } -} -impl From for KnownRecord { - fn from(record_data: crate::lexicons::my::spoor::log::session::RecordData) -> Self { - KnownRecord::LexiconsMySpoorLogSession(Box::new(record_data.into())) - } -} -impl Into for KnownRecord { - fn into(self) -> atrium_api::types::Unknown { - atrium_api::types::TryIntoUnknown::try_into_unknown(&self).unwrap() - } -} diff --git a/atproto/src/lib.rs b/atproto/src/lib.rs index 53c8d32..021b86c 100644 --- a/atproto/src/lib.rs +++ b/atproto/src/lib.rs @@ -1,5 +1,7 @@ -pub mod lexicons; +// pub mod lexicons; pub mod types; pub mod error; #[cfg(feature = "sqlx-support")] pub mod sqlx; + +pub use atrium_api::types::Collection; diff --git a/atproto/src/types.rs b/atproto/src/types.rs index 26ee60d..2e1c387 100644 --- a/atproto/src/types.rs +++ b/atproto/src/types.rs @@ -1,5 +1,20 @@ use crate::error::{Error, ParseError}; +#[macro_export] +macro_rules! basic_deserializer { + ($name:ident) => { + impl<'de> serde::de::Deserialize<'de> for $name { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let value: String = serde::de::Deserialize::deserialize(deserializer)?; + value.parse::<$name>().map_err(::custom) + } + } + } +} + macro_rules! basic_string_type { ($name:ident, $regex:literal, $max_len:literal) => { pub struct $name { value: String, } @@ -35,20 +50,26 @@ macro_rules! basic_string_type { }) } } + + basic_deserializer!($name); } } -mod did; -pub use did::Did; -mod cid; -pub use cid::Cid; mod authority; -pub use authority::Authority; +mod bound_string; +mod cid; mod datetime; -pub use datetime::Datetime; +mod did; mod record_key; -pub use record_key::RecordKey; +mod strong_ref; mod uri; +pub use authority::Authority; +pub use bound_string::BoundString; +pub use cid::Cid; +pub use datetime::Datetime; +pub use did::Did; +pub use record_key::RecordKey; +pub use strong_ref::StrongRef; pub use uri::Uri; basic_string_type!(Handle, @@ -63,22 +84,3 @@ basic_string_type!(Tid, r"^[234567abcdefghij][234567abcdefghijklmnopqrstuvwxyz]{12}$", 13 ); - -pub struct StrongRef { - content: T, - cid: Cid, -} - -impl StrongRef { - pub fn get_content(&self) -> &T { - &self.content - } - - pub fn extract_content(self) -> (T, Cid) { - (self.content, self.cid) - } - - pub fn get_cid(&self) -> &Cid { - &self.cid - } -} diff --git a/atproto/src/types/authority.rs b/atproto/src/types/authority.rs index 1c76750..ad3d68c 100644 --- a/atproto/src/types/authority.rs +++ b/atproto/src/types/authority.rs @@ -1,12 +1,16 @@ use crate::{ - types::{Did, Handle}, + types::{ + Did, Handle + }, error::{Error, ParseError}, }; +use serde::Deserialize; use std::{ fmt::{Display, Formatter, Result as FmtResult}, str::FromStr, }; +#[derive(Deserialize)] pub enum Authority { Did(Did), Handle(Handle), diff --git a/atproto/src/types/bound_string.rs b/atproto/src/types/bound_string.rs new file mode 100644 index 0000000..db81bcf --- /dev/null +++ b/atproto/src/types/bound_string.rs @@ -0,0 +1,57 @@ +use unicode_segmentation::UnicodeSegmentation; +use crate::error::{Error, ParseError}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; + +pub struct BoundString< + const MIN: usize, const MAX: usize> +{ + value: String, +} + +impl BoundString { + fn check_length(s: &str) -> Result<(), Error> { + let grapheme_count: usize = s.graphemes(true).take(MAX + 1).count(); + if grapheme_count > MAX { + return Err(Error::Parse { + err: ParseError::Length { max: MAX, got: grapheme_count }, + object: "String".to_string(), + }); + } + if grapheme_count < MIN { + return Err(Error::Parse { + err: ParseError::MinLength { min: MIN, got: grapheme_count }, + object: "String".to_string(), + }); + } + Ok(()) + } +} + +impl Display for BoundString { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", self.value) + } +} + +impl FromStr for BoundString { + type Err = Error; + fn from_str(s: &str) -> Result { + Self::check_length(s)?; + + Ok(BoundString { value: s.to_string() }) + } +} + +impl<'de, const MIN: usize, const MAX: usize> serde::de::Deserialize<'de> + for BoundString { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let value: String = serde::de::Deserialize::deserialize(deserializer)?; + value.parse::>().map_err(::custom) + } +} diff --git a/atproto/src/types/cid.rs b/atproto/src/types/cid.rs index 3581ad2..a1d3ff2 100644 --- a/atproto/src/types/cid.rs +++ b/atproto/src/types/cid.rs @@ -1,4 +1,7 @@ -use crate::error::Error; +use crate::{ + basic_deserializer, + error::Error +}; pub struct Cid { value: String, } @@ -14,3 +17,5 @@ impl std::str::FromStr for Cid { Ok(Self { value: s.to_string() }) } } + +basic_deserializer!(Cid); diff --git a/atproto/src/types/datetime.rs b/atproto/src/types/datetime.rs index 6dcf234..bd5583b 100644 --- a/atproto/src/types/datetime.rs +++ b/atproto/src/types/datetime.rs @@ -1,4 +1,7 @@ -use crate::error::{Error, ParseError, FormatError}; +use crate::{ + basic_deserializer, + error::{Error, ParseError, FormatError}, +}; use time::{ UtcDateTime, format_description::well_known::{Rfc3339, Iso8601}, @@ -61,3 +64,5 @@ impl FromStr for Datetime { Ok(Datetime { time: datetime, derived_string: s.to_string() }) } } + +basic_deserializer!(Datetime); diff --git a/atproto/src/types/did.rs b/atproto/src/types/did.rs index 6c91db2..fa9b739 100644 --- a/atproto/src/types/did.rs +++ b/atproto/src/types/did.rs @@ -1,4 +1,7 @@ -use crate::error::{Error, ParseError}; +use crate::{ + basic_deserializer, + error::{Error, ParseError}, +}; use std::{ fmt::{Display, Formatter, Result as FmtResult}, str::FromStr, @@ -31,6 +34,8 @@ impl FromStr for DidMethod { } } +basic_deserializer!(DidMethod); + pub struct Did { method: DidMethod, identifier: String, @@ -70,3 +75,5 @@ impl FromStr for Did { }) } } + +basic_deserializer!(Did); diff --git a/atproto/src/types/record_key.rs b/atproto/src/types/record_key.rs index 7368b53..8542029 100644 --- a/atproto/src/types/record_key.rs +++ b/atproto/src/types/record_key.rs @@ -61,3 +61,5 @@ impl FromStr for RecordKey { Ok(RecordKey::Any(s.to_string())) } } + +basic_deserializer!(RecordKey); diff --git a/atproto/src/types/strong_ref.rs b/atproto/src/types/strong_ref.rs new file mode 100644 index 0000000..9e23ff7 --- /dev/null +++ b/atproto/src/types/strong_ref.rs @@ -0,0 +1,46 @@ +use crate::{ + basic_deserializer, + types::{Cid, Uri}, + error::{Error, ParseError}, +}; + +pub struct StrongRef { + content: T, + cid: Cid, +} + +impl StrongRef { + pub fn from_atrium_api(strong_ref: atrium_api::com::atproto::repo::strong_ref::MainData) -> Result { + let str_cid = serde_json::to_string(&strong_ref.cid).map_err(|e| { + Error::Parse { err: ParseError::Serde(e), object: "Uri".to_string() } + })?; + Ok(Self { + content: strong_ref.uri.parse::()?, + cid: str_cid.parse::()?, + }) + } +} + +impl StrongRef { + pub fn map_content(self, f: F) -> StrongRef + where + F: FnOnce(T) -> U, + { + StrongRef { + content: f(self.content), + cid: self.cid, + } + } + + pub fn get_content(&self) -> &T { + &self.content + } + + pub fn extract_content(self) -> (T, Cid) { + (self.content, self.cid) + } + + pub fn get_cid(&self) -> &Cid { + &self.cid + } +} diff --git a/atproto/src/types/uri.rs b/atproto/src/types/uri.rs index 843ee2b..1fbf423 100644 --- a/atproto/src/types/uri.rs +++ b/atproto/src/types/uri.rs @@ -1,4 +1,5 @@ use crate::{ + basic_deserializer, types::{Did, Authority, Nsid, RecordKey}, error::{Error, ParseError}, }; @@ -33,47 +34,78 @@ impl Display for Uri { impl FromStr for Uri { type Err = Error; fn from_str(s: &str) -> Result { - if s.len() > 8000 { - return Err(Error::Parse { - err: ParseError::Length { max: 8000, got: s.len() }, - object: "Did".to_string(), - }); - } + Self::check_length(s)?; let Some(( _whole, unchecked_authority, unchecked_collection, unchecked_rkey - )) = regex_captures!( + )): Option<(&str, &str, &str, &str)> = regex_captures!( r"/^at:\/\/([\w\.\-_~:]+)(?:\/([\w\.\-_~:]+)(?:)\/([\w\.\-_~:]+))?$/i", s, - ) else { + ) else { return Err(Error::Parse { err: ParseError::Format, object: "Uri".to_string(), }); }; - let did = match Authority::from_str(unchecked_authority)? { - Authority::Handle(h) => - return Err(Error::Parse { - err: ParseError::ForceDid { handle: h.to_string() }, - object: "Uri".to_string(), - }), - Authority::Did(d) => d, - }; + let did = Self::check_authority(unchecked_authority.to_string())?; let collection = if unchecked_collection.is_empty() { None } - else { Some(unchecked_collection.parse::()?) }; + else { Some(Self::check_collection(unchecked_collection.to_string())?) }; let rkey = if unchecked_rkey.is_empty() { None } - else { Some(unchecked_rkey.parse::()?) }; + else { Some(Self::check_rkey(unchecked_rkey.to_string())?) }; Ok(Uri { authority: did, collection, rkey }) } } impl Uri { + pub fn from_components( + authority_str: String, collection_str: Option, + rkey_str: Option + ) -> Result { + let authority = Self::check_authority(authority_str)?; + let collection = collection_str.map(Self::check_collection).transpose()?; + let rkey = rkey_str.map(Self::check_rkey).transpose()?; + let uri = Uri { authority, collection, rkey }; + Self::check_length(&uri.to_string())?; + + Ok(uri) + } + + fn check_length(s: &str) -> Result<(), Error> { + if s.len() > 8000 { + return Err(Error::Parse { + err: ParseError::Length { max: 8000, got: s.len() }, + object: "Did".to_string(), + }); + } + Ok(()) + } + + fn check_authority(authority: String) -> Result { + Ok(match Authority::from_str(&authority)? { + Authority::Handle(h) => + return Err(Error::Parse { + err: ParseError::ForceDid { handle: h.to_string() }, + object: "Uri".to_string(), + }), + Authority::Did(d) => d, + }) + } + + fn check_collection(collection: String) -> Result { + Ok(collection.parse::()?) + } + + fn check_rkey(rkey: String) -> Result { + Ok(rkey.parse::()?) + } + pub fn authority_as_did(&self) -> &Did { &self.authority } } +basic_deserializer!(Uri); diff --git a/entryway/Cargo.toml b/entryway/Cargo.toml new file mode 100644 index 0000000..e4b9f0f --- /dev/null +++ b/entryway/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "entryway" +version = "0.1.0" +edition = "2024" + +[dependencies] +atproto = { workspace = true, features = ["sqlx-support"] } +router.workspace = true +http = "1.3.1" +serde.workspace = true +serde_json.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +async-trait.workspace = true +sqlx.workspace = true +thiserror.workspace = true +argon2 = "0.5" +time = { version = "0.3", features = ["formatting", "macros"] } diff --git a/entryway/migrations/01_initial_schema.sql b/entryway/migrations/01_initial_schema.sql new file mode 100644 index 0000000..a6b1ea8 --- /dev/null +++ b/entryway/migrations/01_initial_schema.sql @@ -0,0 +1,23 @@ +-- PDS Entryway Account Management Schema +-- Minimal schema for account creation and authentication + +-- Actor table - stores public identity information +CREATE TABLE actor ( + did VARCHAR PRIMARY KEY, + handle VARCHAR, + created_at VARCHAR NOT NULL +); + +-- Case-insensitive unique index on handle +CREATE UNIQUE INDEX actor_handle_lower_idx ON actor (LOWER(handle)); + +-- Account table - stores private authentication data +CREATE TABLE account ( + did VARCHAR PRIMARY KEY, + email VARCHAR NOT NULL, + password_scrypt VARCHAR NOT NULL, + email_confirmed_at VARCHAR +); + +-- Case-insensitive unique index on email +CREATE UNIQUE INDEX account_email_lower_idx ON account (LOWER(email)); diff --git a/entryway/migrations/20250828184830_initial_schema.sql b/entryway/migrations/20250828184830_initial_schema.sql new file mode 100644 index 0000000..a6b1ea8 --- /dev/null +++ b/entryway/migrations/20250828184830_initial_schema.sql @@ -0,0 +1,23 @@ +-- PDS Entryway Account Management Schema +-- Minimal schema for account creation and authentication + +-- Actor table - stores public identity information +CREATE TABLE actor ( + did VARCHAR PRIMARY KEY, + handle VARCHAR, + created_at VARCHAR NOT NULL +); + +-- Case-insensitive unique index on handle +CREATE UNIQUE INDEX actor_handle_lower_idx ON actor (LOWER(handle)); + +-- Account table - stores private authentication data +CREATE TABLE account ( + did VARCHAR PRIMARY KEY, + email VARCHAR NOT NULL, + password_scrypt VARCHAR NOT NULL, + email_confirmed_at VARCHAR +); + +-- Case-insensitive unique index on email +CREATE UNIQUE INDEX account_email_lower_idx ON account (LOWER(email)); diff --git a/entryway/src/database/error.rs b/entryway/src/database/error.rs new file mode 100644 index 0000000..855e280 --- /dev/null +++ b/entryway/src/database/error.rs @@ -0,0 +1,13 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum DatabaseError { + #[error("Database connection error: {0}")] + Connection(#[from] sqlx::Error), + + #[error("Handle already taken: {0}")] + HandleTaken(String), + + #[error("Email already taken: {0}")] + EmailTaken(String), +} \ No newline at end of file diff --git a/entryway/src/database/mod.rs b/entryway/src/database/mod.rs new file mode 100644 index 0000000..fc13e9a --- /dev/null +++ b/entryway/src/database/mod.rs @@ -0,0 +1,5 @@ +pub mod error; +pub mod operations; + +pub use error::DatabaseError; +pub use operations::Database; \ No newline at end of file diff --git a/entryway/src/database/operations.rs b/entryway/src/database/operations.rs new file mode 100644 index 0000000..b5ec548 --- /dev/null +++ b/entryway/src/database/operations.rs @@ -0,0 +1,82 @@ +use sqlx::{Pool, Postgres}; +use atproto::types::{ + Handle + Did +}; +use crate::database::DatabaseError; + +pub struct Database { + pool: Pool, +} + +impl Database { + pub fn new(pool: Pool) -> Self { + Self { pool } + } + + // Account availability checking + pub async fn check_handle_available(&self, handle: &Handle) -> Result { + let count = sqlx::query_scalar!( + "SELECT COUNT(*) FROM actor WHERE LOWER(handle) = LOWER($1)", + handle + ) + .fetch_one(&self.pool) + .await?; + + Ok(count.unwrap_or(0) == 0) + } + + pub async fn check_email_available(&self, email: &str) -> Result { + let count = sqlx::query_scalar!( + "SELECT COUNT(*) FROM account WHERE LOWER(email) = LOWER($1)", + email + ) + .fetch_one(&self.pool) + .await?; + + Ok(count.unwrap_or(0) == 0) + } + + // Account creation + pub async fn create_account( + &self, + did: &Did, + handle: &Handle, + email: &str, + password_hash: &str, + created_at: &str, + ) -> Result<(), DatabaseError> { + // Use a transaction to ensure both actor and account records are created together + let mut tx = self.pool.begin().await?; + + // Insert into actor table + sqlx::query!( + r#" + INSERT INTO actor (did, handle, created_at) + VALUES ($1, $2, $3) + "#, + did, + handle, + created_at + ) + .execute(&mut *tx) + .await?; + + // Insert into account table + sqlx::query!( + r#" + INSERT INTO account (did, email, password_scrypt) + VALUES ($1, $2, $3) + "#, + did, + email, + password_hash + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(()) + } +} diff --git a/entryway/src/main.rs b/entryway/src/main.rs new file mode 100644 index 0000000..98153a4 --- /dev/null +++ b/entryway/src/main.rs @@ -0,0 +1,64 @@ +use router::{ + Router, + xrpc::XrpcEndpoint, +}; +use atproto::types::Nsid; +use sqlx::{Pool, Postgres}; +use std::env; +use tracing::{event, Level}; + +mod xrpc; +mod database; + +use xrpc::create_account; +use database::Database; + +struct Config { + entryway_url: String, + entryway_did: String, + entryway_plc_rotation_key: String, + entryway_jwt_key_256_hex: String, +} + +#[tokio::main] +async fn main() { + let subscriber = tracing_subscriber::FmtSubscriber::new(); + let _ = tracing::subscriber::set_global_default(subscriber); + + // Set up database connection + let database_url = env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://localhost/entryway_dev".to_string()); + + event!(Level::INFO, "Connecting to database: {}", database_url); + + let pool = match sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .connect(&database_url) + .await + { + Ok(pool) => pool, + Err(e) => { + event!(Level::ERROR, "Failed to connect to database: {}", e); + std::process::exit(1); + } + }; + + let database = Database::new(pool); + + // Run migrations + if let Err(e) = database.run_migrations().await { + event!(Level::ERROR, "Failed to run migrations: {}", e); + std::process::exit(1); + } + + event!(Level::INFO, "Database setup complete"); + + // TODO: Wire up database to XRPC handlers + // For now, keeping the existing router setup + let mut router = Router::new(); + let create_account_nsid: Nsid = "com.atproto.server.createAccount".parse::().expect("valid nsid"); + router = router.add_endpoint(XrpcEndpoint::not_implemented()); + router = router.add_endpoint(XrpcEndpoint::new_procedure(create_account_nsid, create_account)); + router.serve().await; +} + diff --git a/entryway/src/xrpc/create_account.rs b/entryway/src/xrpc/create_account.rs new file mode 100644 index 0000000..7f0a0d6 --- /dev/null +++ b/entryway/src/xrpc/create_account.rs @@ -0,0 +1,258 @@ +use router::xrpc::{ProcedureInput, Response, error}; +use serde::{Deserialize, Serialize}; +use http::status::StatusCode; +use tracing::{event, instrument, Level}; +use atproto::types::Handle; +use std::str::FromStr; +use argon2::{Argon2, PasswordHasher, password_hash::{rand_core::OsRng, SaltString}}; +use time::OffsetDateTime; +use crate::database::{Database, DatabaseError}; + +#[derive(Deserialize, Debug)] +pub struct CreateAccountInput { + pub email: Option, + pub handle: String, + pub did: Option, + pub invite_code: Option, + pub verification_code: Option, + pub verification_phone: Option, + pub password: Option, + pub recovery_key: Option, + pub plc_op: Option, +} + +#[derive(Serialize, Debug)] +pub struct CreateAccountResponse { + pub handle: String, + pub did: String, + // pub did_doc: Option, // TODO: Define DidDocument type + pub access_jwt: String, + pub refresh_jwt: String, +} + +#[instrument] +pub async fn create_account(data: ProcedureInput) -> Response { + event!(Level::INFO, "Creating account for handle: {}", data.input.handle); + + // TODO: Get database from context/config + // For now, this won't compile but shows the intended flow + + // 1. Input validation + let validated_input = match validate_inputs(&data.input).await { + Ok(input) => input, + Err(err) => return err, + }; + + // 2. Check handle and email availability + // if let Err(err) = check_availability(&database, &validated_input).await { + // return err; + // } + + // 3. Generate DID (placeholder for now) + let did = generate_placeholder_did(&validated_input.handle).await; + + // 4. Hash password if provided + let password_hash = if let Some(password) = &validated_input.password { + match hash_password(password) { + Ok(hash) => Some(hash), + Err(_) => { + return error( + StatusCode::INTERNAL_SERVER_ERROR, + "InternalServerError", + "Failed to hash password" + ); + } + } + } else { + None + }; + + // 5. Create account in database + let created_at = OffsetDateTime::now_utc().format(&time::format_description::well_known::Iso8601::DEFAULT) + .unwrap_or_else(|_| "unknown".to_string()); + + // if let Err(err) = create_account_in_db(&database, &did, &validated_input, password_hash.as_deref(), &created_at).await { + // return convert_db_error_to_response(err); + // } + + // 6. Generate session tokens (placeholder for now) + let credentials = Credentials { + access_jwt: "placeholder_access_token".to_string(), + refresh_jwt: "placeholder_refresh_token".to_string(), + }; + + // Return success response + let response = CreateAccountResponse { + handle: validated_input.handle.clone(), + did: did.clone(), + access_jwt: credentials.access_jwt, + refresh_jwt: credentials.refresh_jwt, + }; + + event!(Level::INFO, "Account created successfully for DID: {} with handle: {}", did, validated_input.handle); + + // TODO: Replace with proper JSON response encoding + error(StatusCode::OK, "success", "Account created successfully") +} + +// Maximum password length (matches atproto TypeScript implementation) +const NEW_PASSWORD_MAX_LENGTH: usize = 256; + +// TODO: Implement these helper functions + +async fn validate_inputs(input: &CreateAccountInput) -> Result { + // Based on validateInputsForLocalPds in the TypeScript version + + // Validate email is provided and has basic format + let email = match &input.email { + Some(e) if !e.is_empty() => e.clone(), + _ => { + return Err(error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + "Email is required" + )); + } + }; + + // Validate email format (basic validation for now) + // TODO: Improve email validation - add proper RFC validation and disposable email checking + // TypeScript version uses @hapi/address for validation and disposable-email-domains-js for disposable check + if !is_valid_email(&email) { + return Err(error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + "This email address is not supported, please use a different email." + )); + } + + // Validate password length if provided + if let Some(password) = &input.password { + if password.len() > NEW_PASSWORD_MAX_LENGTH { + return Err(error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + &format!("Password too long. Maximum length is {} characters.", NEW_PASSWORD_MAX_LENGTH) + )); + } + } + + // Validate and normalize handle using atproto types + let handle = Handle::from_str(&input.handle).map_err(|_| { + error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + "Invalid handle format" + ) + })?; + + // TODO: Invite codes - not supported for now but leave placeholder + if input.invite_code.is_some() { + event!(Level::INFO, "Invite codes not yet supported, ignoring"); + } + + Ok(ValidatedInput { + handle: handle.to_string(), + email: email.to_lowercase(), // Normalize email to lowercase + password: input.password.clone(), + invite_code: input.invite_code.clone(), + }) +} + +// Basic email validation - checks for @ and . in reasonable positions +// TODO: Replace with proper email validation library like email-address crate +fn is_valid_email(email: &str) -> bool { + // Very basic email validation + let at_pos = email.find('@'); + let last_dot_pos = email.rfind('.'); + + match (at_pos, last_dot_pos) { + (Some(at), Some(dot)) => { + // @ must come before the last dot + // Must have content before @, between @ and dot, and after dot + at > 0 && dot > at + 1 && dot < email.len() - 1 + } + _ => false, + } +} + +async fn check_availability(database: &Database, input: &ValidatedInput) -> Result<(), Response> { + // Check that handle and email are not already taken + match database.check_handle_available(&input.handle).await { + Ok(false) => { + return Err(error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + &format!("Handle already taken: {}", input.handle) + )); + } + Err(err) => { + event!(Level::ERROR, "Database error checking handle availability: {:?}", err); + return Err(error( + StatusCode::INTERNAL_SERVER_ERROR, + "InternalServerError", + "Database error" + )); + } + _ => {} + } + + match database.check_email_available(&input.email).await { + Ok(false) => { + return Err(error( + StatusCode::BAD_REQUEST, + "InvalidRequest", + &format!("Email already taken: {}", input.email) + )); + } + Err(err) => { + event!(Level::ERROR, "Database error checking email availability: {:?}", err); + return Err(error( + StatusCode::INTERNAL_SERVER_ERROR, + "InternalServerError", + "Database error" + )); + } + _ => {} + } + + Ok(()) +} + +async fn generate_placeholder_did(handle: &str) -> String { + // TODO: Replace with actual DID generation (did:plc) + // For now, generate a placeholder DID based on handle + format!("did:placeholder:{}", handle.replace(".", "-")) +} + +fn hash_password(password: &str) -> Result { + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let password_hash = argon2.hash_password(password.as_bytes(), &salt)?; + Ok(password_hash.to_string()) +} + +async fn create_account_in_db( + database: &Database, + did: &str, + input: &ValidatedInput, + password_hash: Option<&str>, + created_at: &str, +) -> Result<(), DatabaseError> { + let hash = password_hash.unwrap_or(""); // Empty hash if no password + database.create_account(did, &input.handle, &input.email, hash, created_at).await +} + +#[derive(Debug)] +struct ValidatedInput { + handle: String, + email: String, + password: Option, + invite_code: Option, +} + +#[derive(Debug)] +struct Credentials { + access_jwt: String, + refresh_jwt: String, +} \ No newline at end of file diff --git a/entryway/src/xrpc/mod.rs b/entryway/src/xrpc/mod.rs new file mode 100644 index 0000000..96b9473 --- /dev/null +++ b/entryway/src/xrpc/mod.rs @@ -0,0 +1,3 @@ +pub mod create_account; + +pub use create_account::create_account; \ No newline at end of file diff --git a/flake.lock b/flake.lock index 69e701e..c3be24b 100644 --- a/flake.lock +++ b/flake.lock @@ -2,16 +2,18 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1735563628, - "narHash": "sha256-OnSAY7XDSx7CtDoqNh8jwVwh4xNL/2HaJxGjryLWzX8=", - "rev": "b134951a4c9f3c995fd7be05f3243f8ecd65d798", - "revCount": 637546, - "type": "tarball", - "url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.2405.637546%2Brev-b134951a4c9f3c995fd7be05f3243f8ecd65d798/01941dc2-2ab2-7453-8ebd-88712e28efae/source.tar.gz" + "lastModified": 1752436162, + "narHash": "sha256-Kt1UIPi7kZqkSc5HVj6UY5YLHHEzPBkgpNUByuyxtlw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "dfcd5b901dbab46c9c6e80b265648481aafb01f8", + "type": "github" }, "original": { - "type": "tarball", - "url": "https://flakehub.com/f/NixOS/nixpkgs/0.2405.%2A.tar.gz" + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" } }, "nixpkgs_2": { @@ -41,11 +43,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1749695868, - "narHash": "sha256-debjTLOyqqsYOUuUGQsAHskFXH5+Kx2t3dOo/FCoNRA=", + "lastModified": 1752547600, + "narHash": "sha256-0vUE42ji4mcCvQO8CI0Oy8LmC6u2G4qpYldZbZ26MLc=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "55f914d5228b5c8120e9e0f9698ed5b7214d09cd", + "rev": "9127ca1f5a785b23a2fc1c74551a27d3e8b9a28b", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 48d94bf..014140f 100644 --- a/flake.nix +++ b/flake.nix @@ -3,13 +3,15 @@ # Flake inputs inputs = { - nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.2405.*.tar.gz"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; rust-overlay.url = "github:oxalica/rust-overlay"; # A helper for Rust + Nix }; # Flake outputs outputs = { self, nixpkgs, rust-overlay }: let + pdsDirectory = "/home/pan/prog/atproto/appview"; + # Overlays enable you to customize the Nixpkgs attribute set overlays = [ # Makes a `rust-bin` attribute available in Nixpkgs @@ -33,19 +35,88 @@ forAllSystems = f: nixpkgs.lib.genAttrs allSystems (system: f { pkgs = import nixpkgs { inherit overlays system; }; }); + + # Systemd service configuration + createSystemdService = pkgs: pdsDir: pkgs.writeTextFile { + name = "pds.service"; + text = '' + [Unit] + Description=Development Environment Service + After=network-online.target + Wants=network-online.target + + [Service] + Type=simple + ExecStart=${pkgs.pds}/bin/pds + WorkingDirectory=${pdsDir} + EnvironmentFile=${pdsDir}/.env + Environment=PDS_DATA_DIRECTORY=${pdsDir}/.pds-data + Environment=PDS_BLOBSTORE_DISK_LOCATION=${pdsDir}/.pds-data/blocks + ''; + }; + + # Scripts for managing the systemd service + createServiceScripts = pkgs: pdsDir: + let + serviceFile = createSystemdService pkgs pdsDir; + serviceName = "pds"; + in { + startScript = pkgs.writeShellScript "start-dev-service" '' + set -e + + # Create user systemd directory if it doesn't exist + mkdir -p ~/.config/systemd/user + + # Copy service file + cp -f ${serviceFile} ~/.config/systemd/user/${serviceName}.service + + # Reload systemd and start service + systemctl --user daemon-reload + systemctl --user start ${serviceName} + systemctl --user enable ${serviceName} + + systemctl --user status ${serviceName} --no-pager + ''; + + stopScript = pkgs.writeShellScript "stop-dev-service" '' + set -e + if systemctl --user is-enabled --quiet ${serviceName}; then + # Stop and disable service + systemctl --user stop ${serviceName} || true + systemctl --user disable ${serviceName} || true + + # Remove service file + rm -f ~/.config/systemd/user/${serviceName}.service + + # Reload systemd + systemctl --user daemon-reload + fi + ''; + }; in { # Development environment output - devShells = forAllSystems ({ pkgs }: { - default = pkgs.mkShell { - # The Nix packages provided in the environment - packages = (with pkgs; [ - # The package provided by our custom overlay. Includes cargo, Clippy, cargo-fmt, - # rustdoc, rustfmt, and other tools. - sqlx-cli - rustToolchain - ]) ++ pkgs.lib.optionals pkgs.stdenv.isDarwin (with pkgs; [ libiconv ]); - }; - }); + devShells = forAllSystems ({ pkgs }: + let + scripts = createServiceScripts pkgs pdsDirectory; + in { + default = pkgs.mkShell { + # The Nix packages provided in the environment + packages = (with pkgs; [ + # The package provided by our custom overlay. Includes cargo, Clippy, cargo-fmt, + # rustdoc, rustfmt, and other tools. + sqlx-cli + rustToolchain + ]) ++ pkgs.lib.optionals pkgs.stdenv.isDarwin (with pkgs; [ libiconv ]); + + shellHook = pkgs.lib.optionalString pkgs.stdenv.isLinux '' + # Cleanup + ${scripts.stopScript} + + # Start the systemd service + ${scripts.startScript} + ''; + }; + }); }; } diff --git a/router/Cargo.toml b/router/Cargo.toml new file mode 100644 index 0000000..2426dd6 --- /dev/null +++ b/router/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "router" +version = "0.1.0" +edition = "2024" + +[dependencies] +atproto.workspace = true +axum = { version = "0.8.3", features = ["json"] } +bon = "3.6.4" +http = "1.3.1" +serde.workspace = true +serde_json.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true diff --git a/router/src/lib.rs b/router/src/lib.rs new file mode 100644 index 0000000..788556b --- /dev/null +++ b/router/src/lib.rs @@ -0,0 +1,50 @@ +use crate::xrpc::XrpcEndpoint; +use axum::Router as AxumRouter; +use core::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr}; +use tokio::net::TcpListener; + +pub mod xrpc; +pub mod wellknown; + +pub enum Error {} + +pub struct Router { + addr: SocketAddr, + router: AxumRouter, +} +impl Default for Router { + fn default() -> Self { + Self::new() + } +} +impl Router +where + S: Clone + Send + Sync + 'static, +{ + pub fn new() -> Self { + let router = AxumRouter::new(); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127,0,0,1)), 6702); + Router { router, addr } + } + + pub fn with_state(mut self, state: S) -> Router { + self.router = self.with_state(S); + self + } + + pub fn add_endpoint(mut self, endpoint: E) -> Self { + self.router = endpoint.add_to_router(self.router); + self + } + + pub async fn serve(self) { + let listener = TcpListener::bind(self.addr).await.unwrap(); + + axum::serve(listener, self.router).await.unwrap(); + } +} + +pub trait Endpoint { + fn add_to_router(self, router: AxumRouter) -> AxumRouter; +} diff --git a/router/src/wellknown.rs b/router/src/wellknown.rs new file mode 100644 index 0000000..394a794 --- /dev/null +++ b/router/src/wellknown.rs @@ -0,0 +1,19 @@ +use crate::Endpoint; +use axum::{ + routing::method_routing::MethodRouter, + Router as axumRouter, +}; + +pub mod atproto; +pub mod oauth; + +trait WellKnownEndpoint { + fn get_known_route(&self) -> String; + fn get_resolver(self) -> MethodRouter; +} + +impl Endpoint for WK { + fn add_to_router(self, router: axumRouter) -> axumRouter { + router.route(&format!(".well-known/{}", self.get_known_route()), self.get_resolver()) + } +} diff --git a/router/src/wellknown/atproto.rs b/router/src/wellknown/atproto.rs new file mode 100644 index 0000000..0f99da0 --- /dev/null +++ b/router/src/wellknown/atproto.rs @@ -0,0 +1 @@ +pub mod handle_resolution; diff --git a/router/src/wellknown/atproto/handle_resolution.rs b/router/src/wellknown/atproto/handle_resolution.rs new file mode 100644 index 0000000..b3d75ce --- /dev/null +++ b/router/src/wellknown/atproto/handle_resolution.rs @@ -0,0 +1,52 @@ +use crate::{ + wellknown::WellKnownEndpoint, + Error, +}; +use atproto::types::{Handle, Did}; +use axum::{ + routing::{ + method_routing::MethodRouter, + get, + }, + http::{ + StatusCode, + HeaderMap, + }, +}; + +pub struct HandleResolutionEndpoint { + resolver: MethodRouter, +} + +impl HandleResolutionEndpoint { + pub fn new


(handle_resolver: HR) -> Self where + HR: HandleResolver + Clone + { + HandleResolutionEndpoint { + resolver: get(async move | headers: HeaderMap | -> (StatusCode, String) { + let Some(Ok(hostname)) = headers.get("host").map(|header_value| { + header_value.to_str() + }) else { + return (StatusCode::INTERNAL_SERVER_ERROR, String::from("Internal Server Error")); + }; + let Ok(valid_handle) = hostname.parse::() else { + return (StatusCode::NOT_FOUND, String::from("User not found")); + }; + match handle_resolver.call(valid_handle) { + Ok(Some(did)) => (StatusCode::OK, did.to_string()), + Ok(None) => (StatusCode::NOT_FOUND, String::from("User not found")), + Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, String::from("Internal Server Error")), + } + }) + } + } +} + +pub trait HandleResolver: Send + Sync + 'static { + fn call(&self, handle: Handle) -> Result, Error>; +} + +impl WellKnownEndpoint for HandleResolutionEndpoint { + fn get_known_route(&self) -> String { String::from("atproto-did") } + fn get_resolver(self) -> MethodRouter { self.resolver } +} diff --git a/router/src/wellknown/oauth.rs b/router/src/wellknown/oauth.rs new file mode 100644 index 0000000..887d5fd --- /dev/null +++ b/router/src/wellknown/oauth.rs @@ -0,0 +1 @@ +pub mod authorization_server; diff --git a/router/src/wellknown/oauth/authorization_server.rs b/router/src/wellknown/oauth/authorization_server.rs new file mode 100644 index 0000000..d2ae12f --- /dev/null +++ b/router/src/wellknown/oauth/authorization_server.rs @@ -0,0 +1,103 @@ +use serde::de::{ + Error, + Unexpected, +}; +use serde_json::{ + json, + Result, + Value +}; +use bon::Builder; + +trait Metadata { + fn format_metadata(self, required: RequiredMetadata) -> Result; +} + +pub struct RequiredMetadata { + issuer: String, + authorization_endpoint: String, + token_endpoint: String, +} + +impl RequiredMetadata { + fn new( + issuer: String, + authorization_endpoint: String, + token_endpoint: String + ) -> Self { + RequiredMetadata { + issuer, authorization_endpoint, token_endpoint + } + } +} + +#[derive(Builder)] +struct AtprotoMetadata { + additional_response_types_supported: Option>, + additional_grant_types_supported: Option>, + additional_code_challenge_methods_supported: Option>, + additional_token_endpoint_auth_methods_supported: Option>, + additional_token_endpoint_auth_signing_alg_values_supported: Option>, + additional_scopes_supported: Option>, + pushed_authorization_request_endpoint: String, + additional_dpop_signing_alg_values_supported: Option>, +} + +impl AtprotoMetadata { + fn check_fields(&self) -> Result<()> { + // TODO: Issuer check (https scheme, no default port, no path segments + + if self.additional_token_endpoint_auth_signing_alg_values_supported + .as_ref() + .is_none_or(|vec| vec.iter().any(|s| s == "none")) { + return Err(Error::invalid_value( + Unexpected::Other("\"none\" in token_endpoint_auth_signing_alg_values_supported"), + &"\"none\" to be omitted from token_endpoint_auth_signing_alg_values_supported" + )); + } + + Ok(()) + } +} + +impl Metadata for AtprotoMetadata { + fn format_metadata(self, required: RequiredMetadata) -> Result { + self.check_fields()?; + Ok(json!({ + "issuer": required.issuer, + "authorization_endpoint": required.authorization_endpoint, + "token_endpoint": required.token_endpoint, + "response_types_supported": + self.additional_response_types_supported.unwrap_or_default() + .extend(["code".to_string()]), + "grant_types_supported": + self.additional_grant_types_supported.unwrap_or_default() + .extend([ + "authorization_code".to_string(), + "refresh_token".to_string() + ]), + "code_challenge_methods_supported": + self.additional_code_challenge_methods_supported.unwrap_or_default() + .extend(["S256".to_string()]), + "token_endpoint_auth_methods_supported": + self.additional_token_endpoint_auth_methods_supported.unwrap_or_default() + .extend([ + "none".to_string(), + "private_key_jwt".to_string() + ]), + "token_endpoint_auth_signing_alg_values_supported": + self.additional_token_endpoint_auth_signing_alg_values_supported.unwrap_or_default() + .extend(["ES256".to_string()]), + "scopes_supported": + self.additional_scopes_supported.unwrap_or_default() + .extend(["atproto".to_string()]), + "authorization_response_iss_parameter_supported": true, + "require_pushed_authorization_requests": true, + "pushed_authorization_request_endpoint": self.pushed_authorization_request_endpoint, + "dpop_signing_alg_values_supported": + self.additional_dpop_signing_alg_values_supported.unwrap_or_default() + .extend(["ES256".to_string()]), + "client_id_metadata_document_supported": true, + })) + } +} diff --git a/api/src/router/xrpc.rs b/router/src/xrpc.rs similarity index 60% rename from api/src/router/xrpc.rs rename to router/src/xrpc.rs index 500f331..ff62711 100644 --- a/api/src/router/xrpc.rs +++ b/router/src/xrpc.rs @@ -1,9 +1,10 @@ +use crate::Endpoint; use std::{ collections::HashMap, pin::Pin, future::Future, }; -use atproto::Nsid; +use atproto::types::Nsid; use axum::{ extract::{ Json, @@ -51,8 +52,11 @@ pub fn response(code: StatusCode, message: &str) -> Response { error(code, "", message) } -pub struct QueryInput { - parameters: HashMap, +pub struct QueryInput +where S: Clone + Send + Sync + 'static, +{ + pub parameters: HashMap, + pub state: S, } impl FromRequestParts for QueryInput where @@ -60,21 +64,28 @@ where { type Rejection = Response; - async fn from_request_parts(parts: &mut Parts, _state: &S) + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let query_params: Result>, QueryRejection> = Query::try_from_uri(&parts.uri); match query_params { - Ok(p) => Ok(QueryInput { parameters: p.0 }), + Ok(p) => Ok(QueryInput { parameters: p.0, state }), Err(e) => Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())), } } } -pub struct ProcedureInput { - parameters: HashMap, - input: Json, + +#[derive(Debug)] +pub struct ProcedureInput +where S: Clone + Send + Sync + 'static, +{ + pub parameters: HashMap, + pub input: J, + pub state: S, } -impl FromRequest for ProcedureInput + +impl FromRequest for ProcedureInput where + J: for<'de> serde::Deserialize<'de> + Send + 'static, Bytes: FromRequest, S: Send + Sync, { @@ -82,19 +93,15 @@ where async fn from_request(req: Request, state: &S) -> Result { - let query_params: Result>, QueryRejection> = Query::try_from_uri(req.uri()); - let parameters = match query_params { - Ok(p) => p.0, - Err(e) => return Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())), - }; + let parameters: HashMap = Query::try_from_uri(req.uri()) + .map(|p| p.0) + .map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Paramters", &e.body_text()))?; - let json_value = Json::::from_request(req, state).await; - let input: Json = match json_value { - Ok(v) => v, - Err(e) => return Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())), - }; + let input: J = Json::::from_request(req, state).await + .map(|Json(v)| v) + .map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text()))?; - Ok(ProcedureInput { parameters, input }) + Ok(ProcedureInput { parameters, input, state }) } } @@ -112,26 +119,26 @@ where Box::pin((self)(input)) } } -impl XrpcHandler for F +impl XrpcHandler> for F where - F: Fn(ProcedureInput) -> Fut + Send + Sync + 'static, + F: Fn(ProcedureInput) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { - fn call(&self, input: ProcedureInput) + fn call(&self, input: ProcedureInput) -> Pin+ Send>> { Box::pin((self)(input)) } } impl XrpcEndpoint { - pub fn new_query(nsid: Nsid, query: Q) -> Self + pub fn new_query(nsid: Nsid, query: Q) -> Self where Q: XrpcHandler + Clone { XrpcEndpoint { path: Path::Nsid(nsid), - resolver: get(async move | mut parts: Parts | -> Response { - match QueryInput::from_request_parts(&mut parts, &()).await { + resolver: get(async move | mut parts: Parts, state: &S | -> Response { + match QueryInput::from_request_parts(&mut parts, state).await { Ok(qi) => query.call(qi).await, Err(e) => e } @@ -139,14 +146,15 @@ impl XrpcEndpoint { } } - pub fn new_procedure

(nsid: Nsid, procedure: P) -> Self + pub fn new_procedure(nsid: Nsid, procedure: P) -> Self where - P: XrpcHandler + Clone + P: XrpcHandler> + Clone, + J: for<'de> serde::Deserialize<'de> + Send + 'static, { XrpcEndpoint { path: Path::Nsid(nsid), - resolver: post(async move | req: Request | -> Response { - match ProcedureInput::from_request(req, &()).await { + resolver: post(async move | req: Request, state: &S | -> Response { + match ProcedureInput::::from_request(req, &state).await { Ok(pi) => procedure.call(pi).await, Err(e) => e } @@ -154,15 +162,6 @@ impl XrpcEndpoint { } } - pub fn add_to_router(self, router: axumRouter) -> axumRouter { - let path = match self.path { - Path::Nsid(nsid) => &("/xrpc/".to_owned() + nsid.as_str()), - Path::NotImplemented => "/xrpc/{*nsid}", - }; - - router.route(path, self.resolver) - } - pub fn not_implemented() -> Self { let resolver = ( StatusCode::NOT_IMPLEMENTED, @@ -179,3 +178,13 @@ impl XrpcEndpoint { } } +impl Endpoint for XrpcEndpoint { + fn add_to_router(self, router: axumRouter) -> axumRouter { + let path = match self.path { + Path::Nsid(nsid) => &("/xrpc/".to_owned() + &nsid.to_string()), + Path::NotImplemented => "/xrpc/{*nsid}", + }; + + router.route(path, self.resolver) + } +}