diff --git a/Cargo.lock b/Cargo.lock index 56d7dda..73f43bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,7 @@ dependencies = [ "lazy-regex", "serde", "serde_json", + "sqlx", "thiserror 2.0.12", "time", "tracing", @@ -566,6 +567,7 @@ name = "db" version = "0.1.0" dependencies = [ "async-trait", + "atproto", "sqlx", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 779f393..c029fb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ async-trait = "0.1.88" atproto = { path = "./atproto" } serde = "1.0.219" serde_json = "1.0.140" +sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio"] } thiserror = "2.0.12" tokio = { version = "1.45.0", features = ["macros", "rt-multi-thread"] } tracing = "0.1.41" diff --git a/atproto/Cargo.toml b/atproto/Cargo.toml index 033ff0d..3f5445b 100644 --- a/atproto/Cargo.toml +++ b/atproto/Cargo.toml @@ -8,7 +8,12 @@ atrium-api = { version = "0.25.3", default-features = false } lazy-regex = "3.4.1" serde.workspace = true serde_json.workspace = true +sqlx = { workspace = true, optional = true } time = { version = "0.3.41", features = ["parsing", "formatting"] } tracing-subscriber.workspace = true tracing.workspace = true thiserror.workspace = true + +[features] +default = [] +sqlx-support = ["dep:sqlx"] diff --git a/atproto/src/lib.rs b/atproto/src/lib.rs index c3b56ee..53c8d32 100644 --- a/atproto/src/lib.rs +++ b/atproto/src/lib.rs @@ -1,3 +1,5 @@ pub mod lexicons; pub mod types; pub mod error; +#[cfg(feature = "sqlx-support")] +pub mod sqlx; diff --git a/atproto/src/sqlx.rs b/atproto/src/sqlx.rs new file mode 100644 index 0000000..0bb6683 --- /dev/null +++ b/atproto/src/sqlx.rs @@ -0,0 +1,38 @@ +use crate::types::{ + Did, + Cid, + Uri, + Handle, + Datetime, +}; + +macro_rules! implement_sqlx_for_string_type { + ($name:ident) => { + impl sqlx::Type for $name { + fn type_info() -> sqlx::postgres::PgTypeInfo { + >::type_info() + } + } + impl<'q> sqlx::Encode<'q, sqlx::Postgres> for $name { + fn encode_by_ref( + &self, buf: &mut sqlx::postgres::PgArgumentBuffer + ) -> Result { + >::encode_by_ref(&self.to_string(), buf) + } + } + impl<'r> sqlx::Decode<'r, sqlx::Postgres> for $name { + fn decode( + value: sqlx::postgres::PgValueRef<'r> + ) -> Result { + let s = >::decode(value)?; + s.parse::<$name>().map_err(|e| Box::new(e) as sqlx::error::BoxDynError) + } + } + } +} + +implement_sqlx_for_string_type!(Did); +implement_sqlx_for_string_type!(Cid); +implement_sqlx_for_string_type!(Uri); +implement_sqlx_for_string_type!(Handle); +implement_sqlx_for_string_type!(Datetime); diff --git a/db/Cargo.toml b/db/Cargo.toml index d555eb9..49924cf 100644 --- a/db/Cargo.toml +++ b/db/Cargo.toml @@ -5,5 +5,6 @@ edition = "2024" [dependencies] async-trait.workspace = true +atproto = { workspace = true, features = ["sqlx-support"] } sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio"] } tokio.workspace = true