From 781a56028f0a00d3f80334fd73e08adbd87cc073 Mon Sep 17 00:00:00 2001 From: Julia Lange Date: Mon, 16 Jun 2025 15:29:27 -0700 Subject: [PATCH] Atproto, types overhaul and error handling Breaks off from Atrium-rs's types because they are implemented inconsistently, which makes them harder to use. Additionally, I wanted sqlx support so I decided I'd need to reimplement them to some extent anyways. This was done with reference to the atproto documentation but specifically not the atrium-rs codebase so I wouldn't have to think about licenses. This adds the types and error module in atproto. It also touches Cargo.toml for some new dependencies and some shared dependencies. It required thiserror, so I looped that into the workspace meaning that this commit touches db. some things to keep in mind: - There is no CID parsing - None of this is tested, nor are there any tests written. We're playing fast and loose baby~ --- Cargo.toml | 1 + atproto/Cargo.toml | 2 + atproto/src/error.rs | 31 +++++++++++++ atproto/src/lib.rs | 81 +-------------------------------- atproto/src/types.rs | 70 ++++++++++++++++++++++++++++ atproto/src/types/authority.rs | 38 ++++++++++++++++ atproto/src/types/cid.rs | 16 +++++++ atproto/src/types/datetime.rs | 63 +++++++++++++++++++++++++ atproto/src/types/did.rs | 72 +++++++++++++++++++++++++++++ atproto/src/types/record_key.rs | 63 +++++++++++++++++++++++++ atproto/src/types/uri.rs | 79 ++++++++++++++++++++++++++++++++ db/Cargo.toml | 4 +- 12 files changed, 439 insertions(+), 81 deletions(-) create mode 100644 atproto/src/error.rs create mode 100644 atproto/src/types.rs create mode 100644 atproto/src/types/authority.rs create mode 100644 atproto/src/types/cid.rs create mode 100644 atproto/src/types/datetime.rs create mode 100644 atproto/src/types/did.rs create mode 100644 atproto/src/types/record_key.rs create mode 100644 atproto/src/types/uri.rs diff --git a/Cargo.toml b/Cargo.toml index 9129ba7..779f393 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-7,6 +7,7 @@ async-trait = "0.1.88" atproto = { path = "./atproto" } serde = "1.0.219" serde_json = "1.0.140" +thiserror = "2.0.12" tokio = { version = "1.45.0", features = ["macros", "rt-multi-thread"] } tracing = "0.1.41" tracing-subscriber = "0.3.19" diff --git a/atproto/Cargo.toml b/atproto/Cargo.toml index 7d56725..033ff0d 100644 --- a/atproto/Cargo.toml +++ b/atproto/Cargo.toml @@ -8,5 +8,7 @@ atrium-api = { version = "0.25.3", default-features = false } lazy-regex = "3.4.1" serde.workspace = true serde_json.workspace = true +time = { version = "0.3.41", features = ["parsing", "formatting"] } tracing-subscriber.workspace = true tracing.workspace = true +thiserror.workspace = true diff --git a/atproto/src/error.rs b/atproto/src/error.rs new file mode 100644 index 0000000..1db7106 --- /dev/null +++ b/atproto/src/error.rs @@ -0,0 +1,31 @@ +use thiserror::Error as ThisError; + +#[non_exhaustive] +#[derive(Debug, ThisError)] +pub enum Error { + #[error("Error while parsing {object}: {err}")] + Parse { err: ParseError, object: String }, + #[error("Error while formatting {object}: {err}")] + Format { err: FormatError, object: String }, +} + +#[non_exhaustive] +#[derive(Debug, ThisError)] +pub enum FormatError { + #[error("Time Format Error: {0}")] + Datetime(#[from] time::error::Format), +} +#[non_exhaustive] +#[derive(Debug, ThisError)] +pub enum ParseError { + #[error("Time Parse Error: {0}")] + Datetime(#[from] time::error::Parse), + #[error("Length of parsed object too long, max: {max:?}, got: {got:?}.")] + Length { max: usize, got: usize }, + #[error("Currently Did is enforced, cannot use handle, {handle:?}")] + ForceDid { handle: String }, + #[error("Incorrectly formatted")] + Format, +} + +pub type Result<T> = std::result::Result<T, Error>; diff --git a/atproto/src/lib.rs b/atproto/src/lib.rs index a327896..c3b56ee 100644 --- a/atproto/src/lib.rs +++ b/atproto/src/lib.rs @@ -1,80 +1,3 @@ -use lazy_regex::regex_captures; -use core::str::FromStr; -use std::fmt::{ - Display, Formatter, Result as FmtResult 
-}; - -pub use atrium_api::types::{ - Collection, - string::{ - AtIdentifier as Authority, - Datetime, - Did, - Nsid, - RecordKey, - Tid, - Handle, - } -}; - pub mod lexicons; - -pub struct Cid(String); - -impl Display for Cid { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}", self.0) - } -} - -pub struct StrongRef<T> { - pub content: T, - pub cid: Cid, -} - -pub struct Uri { - whole: String, - // These fields could be useful in the future, - // so I'm leaving the code for them. - // authority: Authority, - // collection: Option<Nsid>, - // rkey: Option<RecordKey>, -} - -impl FromStr for Uri { - type Err = &'static str; - fn from_str(uri: &str) -> Result<Self, Self::Err> { - if uri.len() > 8000 { - return Err("Uri too long") - } - - let Some(( - whole, unchecked_authority, unchecked_collection, unchecked_rkey - )) = regex_captures!( - r"/^at:\/\/([\w\.\-_~:]+)(?:\/([\w\.\-_~:]+)(?:)\/([\w\.\-_~:]+))?$/i", - uri, - ) else { - return Err("Invalid Uri"); - }; - - // This parsing is required, but the values don't need to be used yet. - // No compute cost to use them, just storage cost - let _authority = Authority::from_str(unchecked_authority)?; - - let _collection = if unchecked_collection.is_empty() { None } - else { Some(Nsid::new(unchecked_collection.to_string())?) }; - - let _rkey = if unchecked_rkey.is_empty() { None } - else { Some(RecordKey::new(unchecked_rkey.to_string())?) }; - - // Ok(Uri{ whole: whole.to_string(), authority, collection, rkey }) - Ok(Uri { whole: whole.to_string() }) - } -} - -impl Uri { - pub fn as_str(&self) -> &str { - self.whole.as_str() - } -} - +pub mod types; +pub mod error; diff --git a/atproto/src/types.rs b/atproto/src/types.rs new file mode 100644 index 0000000..ca30ded --- /dev/null +++ b/atproto/src/types.rs @@ -0,0 +1,70 @@ +use crate::error::{Error, ParseError}; + +macro_rules! 
basic_string_type { + ($name:ident, $regex:literal, $max_len:literal) => { + pub struct $name { value: String, } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value) + } + } + + impl std::str::FromStr for $name { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + if s.len() > $max_len { + return Err(Error::Parse { + err: ParseError::Length { max: $max_len, got: s.len() }, + object: stringify!($name).to_string(), + }); + } + + if ! lazy_regex::regex_is_match!( + $regex, + s + ) { + return Err(Error::Parse { + err: ParseError::Format, + object: stringify!($name).to_string(), + }); + } + + Ok(Self { + value: s.to_string(), + }) + } + } + } +} + +mod did; +pub use did::Did; +mod cid; +pub use cid::Cid; +mod authority; +pub use authority::Authority; +mod datetime; +pub use datetime::Datetime; +mod record_key; +pub use record_key::RecordKey; +mod uri; +pub use uri::Uri; + +basic_string_type!(Handle, + r"^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$", + 253 +); +basic_string_type!(Nsid, + r"^[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(\.[a-zA-Z]([a-zA-Z0-9]{0,62})?)$", + 317 +); +basic_string_type!(Tid, + r"^[234567abcdefghij][234567abcdefghijklmnopqrstuvwxyz]{12}$", + 13 +); + +pub struct StrongRef<T> { + pub content: T, + pub cid: Cid, +} diff --git a/atproto/src/types/authority.rs b/atproto/src/types/authority.rs new file mode 100644 index 0000000..1c76750 --- /dev/null +++ b/atproto/src/types/authority.rs @@ -0,0 +1,38 @@ +use crate::{ + types::{Did, Handle}, + error::{Error, ParseError}, +}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; + +pub enum Authority { + Did(Did), + Handle(Handle), +} + +impl Display for Authority { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", match self { + Authority::Did(did) => did.to_string(), 
+ Authority::Handle(handle) => handle.to_string(), + }) + } +} + +impl FromStr for Authority { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + if let Ok(did) = s.parse::<Did>() { + return Ok(Authority::Did(did)); + } + if let Ok(handle) = s.parse::<Handle>() { + return Ok(Authority::Handle(handle)); + } + Err(Error::Parse { + err: ParseError::Format, + object: "Authority".to_string(), + }) + } +} diff --git a/atproto/src/types/cid.rs b/atproto/src/types/cid.rs new file mode 100644 index 0000000..3581ad2 --- /dev/null +++ b/atproto/src/types/cid.rs @@ -0,0 +1,16 @@ +use crate::error::Error; + +pub struct Cid { value: String, } + +impl std::fmt::Display for Cid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value) + } +} + +impl std::str::FromStr for Cid { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(Self { value: s.to_string() }) + } +} diff --git a/atproto/src/types/datetime.rs b/atproto/src/types/datetime.rs new file mode 100644 index 0000000..6dcf234 --- /dev/null +++ b/atproto/src/types/datetime.rs @@ -0,0 +1,63 @@ +use crate::error::{Error, ParseError, FormatError}; +use time::{ + UtcDateTime, + format_description::well_known::{Rfc3339, Iso8601}, +}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; +pub use time::{Date, Time}; + +pub struct Datetime { + time: UtcDateTime, + derived_string: String, +} + +impl Datetime { + pub fn now() -> Result<Self, Error> { + Datetime::from_utc(UtcDateTime::now()) + } + + pub fn new(date: Date, time: Time) -> Result<Self, Error> { + Datetime::from_utc(UtcDateTime::new(date, time)) + } + + fn from_utc(utc: UtcDateTime) -> Result<Self, Error> { + Ok(Datetime { + time: utc, + derived_string: match utc.format(&Rfc3339) { + Ok(ds) => ds, + Err(e) => return Err(Error::Format { + err: FormatError::Datetime(e), + object: "Datetime".to_string(), + }), + }, + }) + } +} + +impl Display for Datetime { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", self.derived_string) 
+ } +} + +impl FromStr for Datetime { + type Err = Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + // Parse as both Rfc3339 and Iso8601 to ensure intersection + let dt_rfc3339 = UtcDateTime::parse(s, &Rfc3339); + let dt_iso8601 = UtcDateTime::parse(s, &Iso8601::DEFAULT); + + let datetime = dt_iso8601 + .and(dt_rfc3339) + .map_err(|e| Error::Parse { + err: ParseError::Datetime(e), + object: "Datetime".to_string(), + })?; + + Ok(Datetime { time: datetime, derived_string: s.to_string() }) + } +} diff --git a/atproto/src/types/did.rs b/atproto/src/types/did.rs new file mode 100644 index 0000000..6c91db2 --- /dev/null +++ b/atproto/src/types/did.rs @@ -0,0 +1,72 @@ +use crate::error::{Error, ParseError}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; +use lazy_regex::regex_captures; + +enum DidMethod { + Web, + Plc, +} +impl Display for DidMethod { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", match self { + DidMethod::Web => "web", + DidMethod::Plc => "plc", + }) + } +} +impl FromStr for DidMethod { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "web" => Ok(DidMethod::Web), + "plc" => Ok(DidMethod::Plc), + _ => Err(Error::Parse { + err: ParseError::Format, + object: "DidMethod".to_string(), + }), + } + } +} + +pub struct Did { + method: DidMethod, + identifier: String, +} + +impl Display for Did { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "did:{}:{}", self.method, self.identifier) + } +} + +impl FromStr for Did { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + if s.len() > 2048 { + return Err(Error::Parse { + err: ParseError::Length { max: 2048, got: s.len() }, + object: "Did".to_string(), + }); + } + + let Some(( + _whole, unchecked_method, identifier + )) = regex_captures!( + r"^did:([a-z]+):([a-zA-Z0-9._:%-]*[a-zA-Z0-9._-])$", + s, + ) else { + return Err(Error::Parse { + err: ParseError::Format, + object: "Did".to_string(), 
+ }); + }; + + Ok(Self { + method: unchecked_method.parse::<DidMethod>()?, + identifier: identifier.to_string(), + }) + } +} diff --git a/atproto/src/types/record_key.rs b/atproto/src/types/record_key.rs new file mode 100644 index 0000000..7368b53 --- /dev/null +++ b/atproto/src/types/record_key.rs @@ -0,0 +1,63 @@ +use crate::{ + types::{Nsid, Tid}, + error::{Error, ParseError}, +}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; +use lazy_regex::regex_is_match; + +pub enum RecordKey { + Tid(Tid), + Nsid(Nsid), + Literal(String), + Any(String), +} +impl Display for RecordKey { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", match self { + RecordKey::Tid(tid) => tid.to_string(), + RecordKey::Nsid(nsid) => nsid.to_string(), + RecordKey::Literal(literal) => literal.to_string(), + RecordKey::Any(any) => any.to_string(), + }) + } +} +impl FromStr for RecordKey { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + if s.len() > 512 { + return Err(Error::Parse { + err: ParseError::Length { max: 512, got: s.len() }, + object: "RecordKey".to_string(), + }); + } + + if !( + regex_is_match!( + r"^[a-zA-Z0-9_~.:-]+$", + s, + ) + && s != "." + && s != ".." 
+ ) { + return Err(Error::Parse { + err: ParseError::Format, + object: "RecordKey".to_string(), + }); + } + + // Valid record key, now decide type + if s.starts_with("literal:") { + return Ok(RecordKey::Literal(s.to_string())); + } + if let Ok(tid) = s.parse::<Tid>() { + return Ok(RecordKey::Tid(tid)); + } + if let Ok(nsid) = s.parse::<Nsid>() { + return Ok(RecordKey::Nsid(nsid)); + } + Ok(RecordKey::Any(s.to_string())) + } +} diff --git a/atproto/src/types/uri.rs b/atproto/src/types/uri.rs new file mode 100644 index 0000000..843ee2b --- /dev/null +++ b/atproto/src/types/uri.rs @@ -0,0 +1,79 @@ +use crate::{ + types::{Did, Authority, Nsid, RecordKey}, + error::{Error, ParseError}, +}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + str::FromStr, +}; +use lazy_regex::regex_captures; + +pub struct Uri { + authority: Did, + collection: Option<Nsid>, + rkey: Option<RecordKey>, +} + +impl Display for Uri { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "at://{}", self.authority)?; + + if let Some(collection) = &self.collection { + write!(f, "/{}", collection)?; + + if let Some(rkey) = &self.rkey { + write!(f, "/{}", rkey)?; + } + } + + Ok(()) + } +} + +impl FromStr for Uri { + type Err = Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + if s.len() > 8000 { + return Err(Error::Parse { + err: ParseError::Length { max: 8000, got: s.len() }, + object: "Uri".to_string(), + }); + } + + let Some(( + _whole, unchecked_authority, unchecked_collection, unchecked_rkey + )) = regex_captures!( + r"(?i)^at://([\w\.\-_~:%]+)(?:/([\w\.\-_~:]+)(?:/([\w\.\-_~:]+))?)?$", + s, + ) else { + return Err(Error::Parse { + err: ParseError::Format, + object: "Uri".to_string(), + }); + }; + + let did = match Authority::from_str(unchecked_authority)? 
{ + Authority::Handle(h) => + return Err(Error::Parse { + err: ParseError::ForceDid { handle: h.to_string() }, + object: "Uri".to_string(), + }), + Authority::Did(d) => d, + }; + + let collection = if unchecked_collection.is_empty() { None } + else { Some(unchecked_collection.parse::<Nsid>()?) }; + + let rkey = if unchecked_rkey.is_empty() { None } + else { Some(unchecked_rkey.parse::<RecordKey>()?) }; + + Ok(Uri { authority: did, collection, rkey }) + } +} + +impl Uri { + pub fn authority_as_did(&self) -> &Did { + &self.authority + } +} + diff --git a/db/Cargo.toml b/db/Cargo.toml index dd1d0b9..13b55ae 100644 --- a/db/Cargo.toml +++ b/db/Cargo.toml @@ -4,8 +4,8 @@ version = "0.1.0" edition = "2024" [dependencies] -thiserror = "2.0.12" atproto.workspace = true async-trait.workspace = true sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio"] } +thiserror.workspace = true tokio.workspace = true