Compare commits

...

4 commits

13 changed files with 523 additions and 48 deletions

5
.gitignore vendored Normal file
View file

@ -0,0 +1,5 @@
.env
.pds-data/
target/
tmp/

37
Cargo.lock generated
View file

@ -79,6 +79,18 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
] ]
[[package]]
name = "argon2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash",
]
[[package]] [[package]]
name = "async-lock" name = "async-lock"
version = "3.4.0" version = "3.4.0"
@ -283,6 +295,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.10.4" version = "0.10.4"
@ -667,11 +688,16 @@ dependencies = [
name = "entryway" name = "entryway"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"argon2",
"async-trait",
"atproto", "atproto",
"http 1.3.1", "http 1.3.1",
"router", "router",
"serde", "serde",
"serde_json", "serde_json",
"sqlx",
"thiserror 2.0.12",
"time",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -1589,6 +1615,17 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]] [[package]]
name = "pem-rfc7468" name = "pem-rfc7468"
version = "0.7.0" version = "0.7.0"

View file

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
atproto.workspace = true atproto = { workspace = true, features = ["sqlx-support"] }
router.workspace = true router.workspace = true
http = "1.3.1" http = "1.3.1"
serde.workspace = true serde.workspace = true
@ -12,3 +12,8 @@ serde_json.workspace = true
tokio.workspace = true tokio.workspace = true
tracing-subscriber.workspace = true tracing-subscriber.workspace = true
tracing.workspace = true tracing.workspace = true
async-trait.workspace = true
sqlx.workspace = true
thiserror.workspace = true
argon2 = "0.5"
time = { version = "0.3", features = ["formatting", "macros"] }

View file

@ -0,0 +1,23 @@
-- PDS Entryway Account Management Schema
-- Minimal schema for account creation and authentication
-- Actor table - stores public identity information
CREATE TABLE actor (
did VARCHAR PRIMARY KEY,
handle VARCHAR,
created_at VARCHAR NOT NULL
);
-- Case-insensitive unique index on handle
CREATE UNIQUE INDEX actor_handle_lower_idx ON actor (LOWER(handle));
-- Account table - stores private authentication data
CREATE TABLE account (
did VARCHAR PRIMARY KEY,
email VARCHAR NOT NULL,
password_scrypt VARCHAR NOT NULL,
email_confirmed_at VARCHAR
);
-- Case-insensitive unique index on email
CREATE UNIQUE INDEX account_email_lower_idx ON account (LOWER(email));

View file

@ -0,0 +1,23 @@
-- PDS Entryway Account Management Schema
-- Minimal schema for account creation and authentication
-- Actor table - stores public identity information
CREATE TABLE actor (
did VARCHAR PRIMARY KEY,
handle VARCHAR,
created_at VARCHAR NOT NULL
);
-- Case-insensitive unique index on handle
CREATE UNIQUE INDEX actor_handle_lower_idx ON actor (LOWER(handle));
-- Account table - stores private authentication data
CREATE TABLE account (
did VARCHAR PRIMARY KEY,
email VARCHAR NOT NULL,
password_scrypt VARCHAR NOT NULL,
email_confirmed_at VARCHAR
);
-- Case-insensitive unique index on email
CREATE UNIQUE INDEX account_email_lower_idx ON account (LOWER(email));

View file

@ -0,0 +1,13 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DatabaseError {
#[error("Database connection error: {0}")]
Connection(#[from] sqlx::Error),
#[error("Handle already taken: {0}")]
HandleTaken(String),
#[error("Email already taken: {0}")]
EmailTaken(String),
}

View file

@ -0,0 +1,5 @@
pub mod error;
pub mod operations;
pub use error::DatabaseError;
pub use operations::Database;

View file

@ -0,0 +1,82 @@
use sqlx::{Pool, Postgres};
use atproto::types::{
Handle
Did
};
use crate::database::DatabaseError;
pub struct Database {
pool: Pool<Postgres>,
}
impl Database {
pub fn new(pool: Pool<Postgres>) -> Self {
Self { pool }
}
// Account availability checking
pub async fn check_handle_available(&self, handle: &Handle) -> Result<bool, DatabaseError> {
let count = sqlx::query_scalar!(
"SELECT COUNT(*) FROM actor WHERE LOWER(handle) = LOWER($1)",
handle
)
.fetch_one(&self.pool)
.await?;
Ok(count.unwrap_or(0) == 0)
}
pub async fn check_email_available(&self, email: &str) -> Result<bool, DatabaseError> {
let count = sqlx::query_scalar!(
"SELECT COUNT(*) FROM account WHERE LOWER(email) = LOWER($1)",
email
)
.fetch_one(&self.pool)
.await?;
Ok(count.unwrap_or(0) == 0)
}
// Account creation
pub async fn create_account(
&self,
did: &Did,
handle: &Handle,
email: &str,
password_hash: &str,
created_at: &str,
) -> Result<(), DatabaseError> {
// Use a transaction to ensure both actor and account records are created together
let mut tx = self.pool.begin().await?;
// Insert into actor table
sqlx::query!(
r#"
INSERT INTO actor (did, handle, created_at)
VALUES ($1, $2, $3)
"#,
did,
handle,
created_at
)
.execute(&mut *tx)
.await?;
// Insert into account table
sqlx::query!(
r#"
INSERT INTO account (did, email, password_scrypt)
VALUES ($1, $2, $3)
"#,
did,
email,
password_hash
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
}

View file

@ -1,21 +1,17 @@
use router::{ use router::{
Router, Router,
xrpc::{ xrpc::XrpcEndpoint,
XrpcEndpoint,
ProcedureInput,
Response,
error,
},
}; };
use serde::Deserialize;
use atproto::types::Nsid; use atproto::types::Nsid;
use http::status::StatusCode; use sqlx::{Pool, Postgres};
use tracing::{ use std::env;
event, use tracing::{event, Level};
instrument,
Level, mod xrpc;
}; mod database;
use std::fmt::Debug;
use xrpc::create_account;
use database::Database;
struct Config { struct Config {
entryway_url: String, entryway_url: String,
@ -29,6 +25,36 @@ async fn main() {
let subscriber = tracing_subscriber::FmtSubscriber::new(); let subscriber = tracing_subscriber::FmtSubscriber::new();
let _ = tracing::subscriber::set_global_default(subscriber); let _ = tracing::subscriber::set_global_default(subscriber);
// Set up database connection
let database_url = env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://localhost/entryway_dev".to_string());
event!(Level::INFO, "Connecting to database: {}", database_url);
let pool = match sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
{
Ok(pool) => pool,
Err(e) => {
event!(Level::ERROR, "Failed to connect to database: {}", e);
std::process::exit(1);
}
};
let database = Database::new(pool);
// Run migrations
if let Err(e) = database.run_migrations().await {
event!(Level::ERROR, "Failed to run migrations: {}", e);
std::process::exit(1);
}
event!(Level::INFO, "Database setup complete");
// TODO: Wire up database to XRPC handlers
// For now, keeping the existing router setup
let mut router = Router::new(); let mut router = Router::new();
let create_account_nsid: Nsid = "com.atproto.server.createAccount".parse::<Nsid>().expect("valid nsid"); let create_account_nsid: Nsid = "com.atproto.server.createAccount".parse::<Nsid>().expect("valid nsid");
router = router.add_endpoint(XrpcEndpoint::not_implemented()); router = router.add_endpoint(XrpcEndpoint::not_implemented());
@ -36,22 +62,3 @@ async fn main() {
router.serve().await; router.serve().await;
} }
#[derive(Deserialize, Debug)]
struct CreateAccountInput {
email: Option<String>,
handle: String,
did: Option<String>,
invite_code: Option<String>,
verification_code: Option<String>,
verification_phone: Option<String>,
password: Option<String>,
recovery_key: Option<String>,
plc_op: Option<String>,
}
#[instrument]
async fn create_account(data: ProcedureInput<CreateAccountInput>) -> Response {
event!(Level::INFO, "In create_account");
error(StatusCode::OK, "error", "message")
}

View file

@ -0,0 +1,258 @@
use router::xrpc::{ProcedureInput, Response, error};
use serde::{Deserialize, Serialize};
use http::status::StatusCode;
use tracing::{event, instrument, Level};
use atproto::types::Handle;
use std::str::FromStr;
use argon2::{Argon2, PasswordHasher, password_hash::{rand_core::OsRng, SaltString}};
use time::OffsetDateTime;
use crate::database::{Database, DatabaseError};
#[derive(Deserialize, Debug)]
pub struct CreateAccountInput {
pub email: Option<String>,
pub handle: String,
pub did: Option<String>,
pub invite_code: Option<String>,
pub verification_code: Option<String>,
pub verification_phone: Option<String>,
pub password: Option<String>,
pub recovery_key: Option<String>,
pub plc_op: Option<String>,
}
#[derive(Serialize, Debug)]
pub struct CreateAccountResponse {
pub handle: String,
pub did: String,
// pub did_doc: Option<DidDocument>, // TODO: Define DidDocument type
pub access_jwt: String,
pub refresh_jwt: String,
}
#[instrument]
pub async fn create_account(data: ProcedureInput<CreateAccountInput>) -> Response {
event!(Level::INFO, "Creating account for handle: {}", data.input.handle);
// TODO: Get database from context/config
// For now, this won't compile but shows the intended flow
// 1. Input validation
let validated_input = match validate_inputs(&data.input).await {
Ok(input) => input,
Err(err) => return err,
};
// 2. Check handle and email availability
// if let Err(err) = check_availability(&database, &validated_input).await {
// return err;
// }
// 3. Generate DID (placeholder for now)
let did = generate_placeholder_did(&validated_input.handle).await;
// 4. Hash password if provided
let password_hash = if let Some(password) = &validated_input.password {
match hash_password(password) {
Ok(hash) => Some(hash),
Err(_) => {
return error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalServerError",
"Failed to hash password"
);
}
}
} else {
None
};
// 5. Create account in database
let created_at = OffsetDateTime::now_utc().format(&time::format_description::well_known::Iso8601::DEFAULT)
.unwrap_or_else(|_| "unknown".to_string());
// if let Err(err) = create_account_in_db(&database, &did, &validated_input, password_hash.as_deref(), &created_at).await {
// return convert_db_error_to_response(err);
// }
// 6. Generate session tokens (placeholder for now)
let credentials = Credentials {
access_jwt: "placeholder_access_token".to_string(),
refresh_jwt: "placeholder_refresh_token".to_string(),
};
// Return success response
let response = CreateAccountResponse {
handle: validated_input.handle.clone(),
did: did.clone(),
access_jwt: credentials.access_jwt,
refresh_jwt: credentials.refresh_jwt,
};
event!(Level::INFO, "Account created successfully for DID: {} with handle: {}", did, validated_input.handle);
// TODO: Replace with proper JSON response encoding
error(StatusCode::OK, "success", "Account created successfully")
}
// Maximum password length (matches atproto TypeScript implementation)
const NEW_PASSWORD_MAX_LENGTH: usize = 256;
// TODO: Implement these helper functions
async fn validate_inputs(input: &CreateAccountInput) -> Result<ValidatedInput, Response> {
// Based on validateInputsForLocalPds in the TypeScript version
// Validate email is provided and has basic format
let email = match &input.email {
Some(e) if !e.is_empty() => e.clone(),
_ => {
return Err(error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
"Email is required"
));
}
};
// Validate email format (basic validation for now)
// TODO: Improve email validation - add proper RFC validation and disposable email checking
// TypeScript version uses @hapi/address for validation and disposable-email-domains-js for disposable check
if !is_valid_email(&email) {
return Err(error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
"This email address is not supported, please use a different email."
));
}
// Validate password length if provided
if let Some(password) = &input.password {
if password.len() > NEW_PASSWORD_MAX_LENGTH {
return Err(error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
&format!("Password too long. Maximum length is {} characters.", NEW_PASSWORD_MAX_LENGTH)
));
}
}
// Validate and normalize handle using atproto types
let handle = Handle::from_str(&input.handle).map_err(|_| {
error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
"Invalid handle format"
)
})?;
// TODO: Invite codes - not supported for now but leave placeholder
if input.invite_code.is_some() {
event!(Level::INFO, "Invite codes not yet supported, ignoring");
}
Ok(ValidatedInput {
handle: handle.to_string(),
email: email.to_lowercase(), // Normalize email to lowercase
password: input.password.clone(),
invite_code: input.invite_code.clone(),
})
}
// Basic email validation - checks for @ and . in reasonable positions
// TODO: Replace with proper email validation library like email-address crate
fn is_valid_email(email: &str) -> bool {
// Very basic email validation
let at_pos = email.find('@');
let last_dot_pos = email.rfind('.');
match (at_pos, last_dot_pos) {
(Some(at), Some(dot)) => {
// @ must come before the last dot
// Must have content before @, between @ and dot, and after dot
at > 0 && dot > at + 1 && dot < email.len() - 1
}
_ => false,
}
}
async fn check_availability(database: &Database, input: &ValidatedInput) -> Result<(), Response> {
// Check that handle and email are not already taken
match database.check_handle_available(&input.handle).await {
Ok(false) => {
return Err(error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
&format!("Handle already taken: {}", input.handle)
));
}
Err(err) => {
event!(Level::ERROR, "Database error checking handle availability: {:?}", err);
return Err(error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalServerError",
"Database error"
));
}
_ => {}
}
match database.check_email_available(&input.email).await {
Ok(false) => {
return Err(error(
StatusCode::BAD_REQUEST,
"InvalidRequest",
&format!("Email already taken: {}", input.email)
));
}
Err(err) => {
event!(Level::ERROR, "Database error checking email availability: {:?}", err);
return Err(error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalServerError",
"Database error"
));
}
_ => {}
}
Ok(())
}
async fn generate_placeholder_did(handle: &str) -> String {
// TODO: Replace with actual DID generation (did:plc)
// For now, generate a placeholder DID based on handle
format!("did:placeholder:{}", handle.replace(".", "-"))
}
fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2.hash_password(password.as_bytes(), &salt)?;
Ok(password_hash.to_string())
}
async fn create_account_in_db(
database: &Database,
did: &str,
input: &ValidatedInput,
password_hash: Option<&str>,
created_at: &str,
) -> Result<(), DatabaseError> {
let hash = password_hash.unwrap_or(""); // Empty hash if no password
database.create_account(did, &input.handle, &input.email, hash, created_at).await
}
#[derive(Debug)]
struct ValidatedInput {
handle: String,
email: String,
password: Option<String>,
invite_code: Option<String>,
}
#[derive(Debug)]
struct Credentials {
access_jwt: String,
refresh_jwt: String,
}

3
entryway/src/xrpc/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod create_account;
pub use create_account::create_account;

View file

@ -18,13 +18,21 @@ impl Default for Router {
Self::new() Self::new()
} }
} }
impl Router { impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new() -> Self { pub fn new() -> Self {
let router = AxumRouter::new(); let router = AxumRouter::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127,0,0,1)), 6702); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127,0,0,1)), 6702);
Router { router, addr } Router { router, addr }
} }
pub fn with_state<S2>(mut self, state: S) -> Router<S2> {
self.router = self.with_state(S);
self
}
pub fn add_endpoint<E: Endpoint>(mut self, endpoint: E) -> Self { pub fn add_endpoint<E: Endpoint>(mut self, endpoint: E) -> Self {
self.router = endpoint.add_to_router(self.router); self.router = endpoint.add_to_router(self.router);
self self

View file

@ -52,8 +52,11 @@ pub fn response(code: StatusCode, message: &str) -> Response {
error(code, "", message) error(code, "", message)
} }
pub struct QueryInput { pub struct QueryInput<S = ()>
where S: Clone + Send + Sync + 'static,
{
pub parameters: HashMap<String, String>, pub parameters: HashMap<String, String>,
pub state: S,
} }
impl<S> FromRequestParts<S> for QueryInput impl<S> FromRequestParts<S> for QueryInput
where where
@ -61,23 +64,26 @@ where
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) async fn from_request_parts(parts: &mut Parts, state: &S)
-> Result<Self, Self::Rejection> { -> Result<Self, Self::Rejection> {
let query_params: Result<Query<HashMap<String, String>>, QueryRejection> = Query::try_from_uri(&parts.uri); let query_params: Result<Query<HashMap<String, String>>, QueryRejection> = Query::try_from_uri(&parts.uri);
match query_params { match query_params {
Ok(p) => Ok(QueryInput { parameters: p.0 }), Ok(p) => Ok(QueryInput { parameters: p.0, state }),
Err(e) => Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())), Err(e) => Err(error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text())),
} }
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ProcedureInput<J> { pub struct ProcedureInput<J, S = ()>
where S: Clone + Send + Sync + 'static,
{
pub parameters: HashMap<String, String>, pub parameters: HashMap<String, String>,
pub input: J, pub input: J,
pub state: S,
} }
impl<J, S> FromRequest<S> for ProcedureInput<J> impl<J, S> FromRequest<S> for ProcedureInput<J, S>
where where
J: for<'de> serde::Deserialize<'de> + Send + 'static, J: for<'de> serde::Deserialize<'de> + Send + 'static,
Bytes: FromRequest<S>, Bytes: FromRequest<S>,
@ -95,7 +101,7 @@ where
.map(|Json(v)| v) .map(|Json(v)| v)
.map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text()))?; .map_err(|e| error(StatusCode::BAD_REQUEST, "Bad Parameters", &e.body_text()))?;
Ok(ProcedureInput { parameters, input }) Ok(ProcedureInput { parameters, input, state })
} }
} }
@ -125,14 +131,14 @@ where
} }
impl XrpcEndpoint { impl XrpcEndpoint {
pub fn new_query<Q>(nsid: Nsid, query: Q) -> Self pub fn new_query<Q, S>(nsid: Nsid, query: Q) -> Self
where where
Q: XrpcHandler<QueryInput> + Clone Q: XrpcHandler<QueryInput> + Clone
{ {
XrpcEndpoint { XrpcEndpoint {
path: Path::Nsid(nsid), path: Path::Nsid(nsid),
resolver: get(async move | mut parts: Parts | -> Response { resolver: get(async move | mut parts: Parts, state: &S | -> Response {
match QueryInput::from_request_parts(&mut parts, &()).await { match QueryInput<S>::from_request_parts(&mut parts, state).await {
Ok(qi) => query.call(qi).await, Ok(qi) => query.call(qi).await,
Err(e) => e Err(e) => e
} }
@ -140,15 +146,15 @@ impl XrpcEndpoint {
} }
} }
pub fn new_procedure<P, J>(nsid: Nsid, procedure: P) -> Self pub fn new_procedure<P, J, S>(nsid: Nsid, procedure: P) -> Self
where where
P: XrpcHandler<ProcedureInput<J>> + Clone, P: XrpcHandler<ProcedureInput<J, S>> + Clone,
J: for<'de> serde::Deserialize<'de> + Send + 'static, J: for<'de> serde::Deserialize<'de> + Send + 'static,
{ {
XrpcEndpoint { XrpcEndpoint {
path: Path::Nsid(nsid), path: Path::Nsid(nsid),
resolver: post(async move | req: Request | -> Response { resolver: post(async move | req: Request, state: &S | -> Response {
match ProcedureInput::<J>::from_request(req, &()).await { match ProcedureInput::<J, S>::from_request(req, &state).await {
Ok(pi) => procedure.call(pi).await, Ok(pi) => procedure.call(pi).await,
Err(e) => e Err(e) => e
} }