From 3c5154d44936a4d792eb965b5004a38106b0ceb5 Mon Sep 17 00:00:00 2001 From: asonix Date: Thu, 19 Mar 2020 17:19:05 -0500 Subject: [PATCH] Use single connection pool --- Cargo.lock | 1 + Cargo.toml | 1 + src/db.rs | 102 ++++++++++++++++++++++++------ src/db_actor.rs | 164 ------------------------------------------------ src/error.rs | 39 ++++++------ src/inbox.rs | 2 +- src/label.rs | 35 ----------- src/main.rs | 18 ++---- src/state.rs | 58 ++++------------- 9 files changed, 123 insertions(+), 297 deletions(-) delete mode 100644 src/db_actor.rs delete mode 100644 src/label.rs diff --git a/Cargo.lock b/Cargo.lock index 4a32929..ab4e8e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,6 +387,7 @@ dependencies = [ "http-signature-normalization-actix", "log", "lru", + "num_cpus", "pretty_env_logger", "rand", "rsa", diff --git a/Cargo.toml b/Cargo.toml index 9584ec2..4dc1f23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ futures = "0.3.4" http-signature-normalization-actix = { version = "0.3.0-alpha.5", default-features = false, features = ["sha-2"] } log = "0.4" lru = "0.4.3" +num_cpus = "1.12" pretty_env_logger = "0.4.0" rand = "0.7" rsa = "0.2" diff --git a/src/db.rs b/src/db.rs index eac926c..4791430 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,16 +1,80 @@ +use crate::error::MyError; use activitystreams::primitives::XsdAnyUri; -use anyhow::Error; -use bb8_postgres::tokio_postgres::{row::Row, Client}; +use bb8_postgres::{ + bb8, + tokio_postgres::{row::Row, Client, Config, NoTls}, + PostgresConnectionManager, +}; use log::{info, warn}; use rsa::RSAPrivateKey; use rsa_pem::KeyExt; -use std::collections::HashSet; +use std::{collections::HashSet, convert::TryInto}; -#[derive(Clone, Debug, thiserror::Error)] -#[error("No host present in URI")] -pub struct HostError; +pub type Pool = bb8::Pool>; -pub async fn listen(client: &Client) -> Result<(), Error> { +#[derive(Clone)] +pub struct Db { + pool: Pool, +} + +impl Db { + pub async fn build(config: Config) -> Result { + let manager = PostgresConnectionManager::new(config, NoTls); + + let pool = bb8::Pool::builder() + .max_size((num_cpus::get() * 4).try_into()?) + .build(manager) + .await?; + + Ok(Db { pool }) + } + + pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> { + let conn = self.pool.get().await?; + + remove_listener(&conn, &inbox).await?; + Ok(()) + } + + pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> { + let conn = self.pool.get().await?; + + add_listener(&conn, &inbox).await?; + Ok(()) + } + + pub async fn hydrate_blocks(&self) -> Result, MyError> { + let conn = self.pool.get().await?; + + Ok(hydrate_blocks(&conn).await?) + } + + pub async fn hydrate_whitelists(&self) -> Result, MyError> { + let conn = self.pool.get().await?; + + Ok(hydrate_whitelists(&conn).await?) + } + + pub async fn hydrate_listeners(&self) -> Result, MyError> { + let conn = self.pool.get().await?; + + Ok(hydrate_listeners(&conn).await?) + } + + pub async fn hydrate_private_key(&self) -> Result, MyError> { + let conn = self.pool.get().await?; + + Ok(hydrate_private_key(&conn).await?) + } + + pub async fn update_private_key(&self, private_key: &RSAPrivateKey) -> Result<(), MyError> { + let conn = self.pool.get().await?; + + Ok(update_private_key(&conn, private_key).await?) + } +} + +pub async fn listen(client: &Client) -> Result<(), MyError> { info!("LISTEN new_blocks;"); info!("LISTEN new_whitelists;"); info!("LISTEN new_listeners;"); @@ -31,7 +95,7 @@ pub async fn listen(client: &Client) -> Result<(), Error> { Ok(()) } -pub async fn hydrate_private_key(client: &Client) -> Result, Error> { +async fn hydrate_private_key(client: &Client) -> Result, MyError> { info!("SELECT value FROM settings WHERE key = 'private_key'"); let rows = client .query("SELECT value FROM settings WHERE key = 'private_key'", &[]) @@ -45,7 +109,7 @@ pub async fn hydrate_private_key(client: &Client) -> Result Result<(), Error> { +async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), MyError> { let pem_pkcs8 = key.to_pem_pkcs8()?; info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');"); @@ -53,11 +117,11 @@ pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result< Ok(()) } -pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> { +async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), MyError> { let host = if let Some(host) = block.as_url().host() { host } else { - return Err(HostError.into()); + return Err(MyError::Host(block.to_string())); }; info!( @@ -74,11 +138,11 @@ pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> Ok(()) } -pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Error> { +async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), MyError> { let host = if let Some(host) = whitelist.as_url().host() { host } else { - return Err(HostError.into()); + return Err(MyError::Host(whitelist.to_string())); }; info!( @@ -95,7 +159,7 @@ pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Ok(()) } -pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { +async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> { info!( "DELETE FROM listeners WHERE actor_id = {};", listener.as_str() @@ -110,7 +174,7 @@ pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<() Ok(()) } -pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { +async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> { info!( "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]", listener.as_str(), @@ -125,14 +189,14 @@ pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), E Ok(()) } -pub async fn hydrate_blocks(client: &Client) -> Result, Error> { +async fn hydrate_blocks(client: &Client) -> Result, MyError> { info!("SELECT domain_name FROM blocks"); let rows = client.query("SELECT domain_name FROM blocks", &[]).await?; parse_rows(rows) } -pub async fn hydrate_whitelists(client: &Client) -> Result, Error> { +async fn hydrate_whitelists(client: &Client) -> Result, MyError> { info!("SELECT domain_name FROM whitelists"); let rows = client .query("SELECT domain_name FROM whitelists", &[]) @@ -141,14 +205,14 @@ pub async fn hydrate_whitelists(client: &Client) -> Result, Erro parse_rows(rows) } -pub async fn hydrate_listeners(client: &Client) -> Result, Error> { +async fn hydrate_listeners(client: &Client) -> Result, MyError> { info!("SELECT actor_id FROM listeners"); let rows = client.query("SELECT actor_id FROM listeners", &[]).await?; parse_rows(rows) } -fn parse_rows(rows: Vec) -> Result, Error> +fn parse_rows(rows: Vec) -> Result, MyError> where T: std::str::FromStr + Eq + std::hash::Hash, E: std::fmt::Display, diff --git a/src/db_actor.rs b/src/db_actor.rs deleted file mode 100644 index e9d2af6..0000000 --- a/src/db_actor.rs +++ /dev/null @@ -1,164 +0,0 @@ -use crate::{ - db::{add_listener, remove_listener}, - error::MyError, - label::ArbiterLabel, -}; -use activitystreams::primitives::XsdAnyUri; -use actix::prelude::*; -use bb8_postgres::{bb8, tokio_postgres, PostgresConnectionManager}; -use log::{error, info}; -use tokio::sync::oneshot::{channel, Receiver}; - -#[derive(Clone)] -pub struct Db { - actor: Addr, -} - -pub type Pool = bb8::Pool>; - -pub enum DbActorState { - Waiting(tokio_postgres::Config), - Ready(Pool), -} - -pub struct DbActor { - pool: DbActorState, -} - -pub struct DbQuery(pub F); - -impl Db { - pub fn new(config: tokio_postgres::Config) -> Db { - let actor = Supervisor::start(|_| DbActor { - pool: DbActorState::new_empty(config), - }); - - Db { actor } - } - - pub async fn execute_inline(&self, f: F) -> Result - where - T: Send + 'static, - F: FnOnce(Pool) -> Fut + Send + 'static, - Fut: Future, - { - Ok(self.actor.send(DbQuery(f)).await?.await?) - } - - pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> { - self.execute_inline(move |pool: Pool| { - let inbox = inbox.clone(); - - async move { - let conn = pool.get().await?; - - remove_listener(&conn, &inbox).await - } - }) - .await? - .map_err(MyError::from) - } - - pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> { - self.execute_inline(move |pool: Pool| { - let inbox = inbox.clone(); - - async move { - let conn = pool.get().await?; - - add_listener(&conn, &inbox).await - } - }) - .await? - .map_err(MyError::from) - } -} - -impl DbActorState { - pub fn new_empty(config: tokio_postgres::Config) -> Self { - DbActorState::Waiting(config) - } - - pub async fn new(config: tokio_postgres::Config) -> Result { - let manager = PostgresConnectionManager::new(config, tokio_postgres::tls::NoTls); - let pool = bb8::Pool::builder().max_size(8).build(manager).await?; - - Ok(DbActorState::Ready(pool)) - } -} - -impl Actor for DbActor { - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - info!("Starting DB Actor in {}", ArbiterLabel::get()); - match self.pool { - DbActorState::Waiting(ref config) => { - let fut = - DbActorState::new(config.clone()) - .into_actor(self) - .map(|res, actor, ctx| { - match res { - Ok(pool) => { - info!("DB pool created in {}", ArbiterLabel::get()); - actor.pool = pool; - } - Err(e) => { - error!( - "Error starting DB Actor in {}, {}", - ArbiterLabel::get(), - e - ); - ctx.stop(); - } - }; - }); - - ctx.wait(fut); - } - _ => (), - }; - } -} - -impl Supervised for DbActor {} - -impl Handler> for DbActor -where - F: FnOnce(Pool) -> Fut + 'static, - Fut: Future, - R: Send + 'static, -{ - type Result = ResponseFuture>; - - fn handle(&mut self, msg: DbQuery, ctx: &mut Self::Context) -> Self::Result { - let (tx, rx) = channel(); - - let pool = match self.pool { - DbActorState::Ready(ref pool) => pool.clone(), - _ => { - error!("Tried to query DB before ready"); - return Box::pin(async move { rx }); - } - }; - - ctx.spawn( - async move { - let result = (msg.0)(pool).await; - let _ = tx.send(result); - } - .into_actor(self), - ); - - Box::pin(async move { rx }) - } -} - -impl Message for DbQuery -where - F: FnOnce(Pool) -> Fut, - Fut: Future, - R: Send + 'static, -{ - type Result = Receiver; -} diff --git a/src/error.rs b/src/error.rs index 93f1419..4d394d7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,15 +1,13 @@ use activitystreams::primitives::XsdAnyUriError; -use actix::MailboxError; use actix_web::{error::ResponseError, http::StatusCode, HttpResponse}; use log::error; use rsa_pem::KeyError; use std::{convert::Infallible, io::Error}; -use tokio::sync::oneshot::error::RecvError; #[derive(Debug, thiserror::Error)] pub enum MyError { #[error("Error in db, {0}")] - DbError(#[from] anyhow::Error), + DbError(#[from] bb8_postgres::tokio_postgres::error::Error), #[error("Couldn't parse key, {0}")] Key(#[from] KeyError), @@ -32,9 +30,6 @@ pub enum MyError { #[error("Couldn't parse the signature header")] HeaderValidation(#[from] actix_web::http::header::InvalidHeaderValue), - #[error("Failed to get output of db operation")] - Oneshot(#[from] RecvError), - #[error("Couldn't decode base64")] Base64(#[from] base64::DecodeError), @@ -56,11 +51,14 @@ pub enum MyError { #[error("Wrong ActivityPub kind, {0}")] Kind(String), - #[error("The requested actor's mailbox is closed")] - MailboxClosed, + #[error("No host present in URI, {0}")] + Host(String), - #[error("The requested actor's mailbox has timed out")] - MailboxTimeout, + #[error("Too many CPUs, {0}")] + CpuCount(#[from] std::num::TryFromIntError), + + #[error("Timed out while waiting on db pool")] + DbTimeout, #[error("Invalid algorithm provided to verifier")] Algorithm, @@ -104,6 +102,18 @@ impl ResponseError for MyError { } } +impl From> for MyError +where + T: Into, +{ + fn from(e: bb8_postgres::bb8::RunError) -> Self { + match e { + bb8_postgres::bb8::RunError::User(e) => e.into(), + bb8_postgres::bb8::RunError::TimedOut => MyError::DbTimeout, + } + } +} + impl From for MyError { fn from(i: Infallible) -> Self { match i {} @@ -115,12 +125,3 @@ impl From for MyError { MyError::Rsa(e) } } - -impl From for MyError { - fn from(m: MailboxError) -> MyError { - match m { - MailboxError::Closed => MyError::MailboxClosed, - MailboxError::Timeout => MyError::MailboxTimeout, - } - } -} diff --git a/src/inbox.rs b/src/inbox.rs index 49c964e..cbfaf14 100644 --- a/src/inbox.rs +++ b/src/inbox.rs @@ -1,7 +1,7 @@ use crate::{ accepted, apub::{AcceptedActors, AcceptedObjects, ValidTypes}, - db_actor::Db, + db::Db, error::MyError, requests::Requests, state::{State, UrlKind}, diff --git a/src/label.rs b/src/label.rs deleted file mode 100644 index 2c0108b..0000000 --- a/src/label.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; - -#[derive(Clone, Debug)] -pub struct ArbiterLabelFactory(Arc); - -#[derive(Clone, Debug)] -pub struct ArbiterLabel(usize); - -impl ArbiterLabelFactory { - pub fn new() -> Self { - ArbiterLabelFactory(Arc::new(AtomicUsize::new(0))) - } - - pub fn set_label(&self) { - if !actix::Arbiter::contains_item::() { - let id = self.0.fetch_add(1, Ordering::SeqCst); - actix::Arbiter::set_item(ArbiterLabel(id)); - } - } -} - -impl ArbiterLabel { - pub fn get() -> ArbiterLabel { - actix::Arbiter::get_item(|label: &ArbiterLabel| label.clone()) - } -} - -impl std::fmt::Display for ArbiterLabel { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Arbiter #{}", self.0) - } -} diff --git a/src/main.rs b/src/main.rs index cc2ca1f..c063268 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,10 +10,8 @@ use sha2::{Digest, Sha256}; mod apub; mod db; -mod db_actor; mod error; mod inbox; -mod label; mod nodeinfo; mod notify; mod requests; @@ -23,9 +21,8 @@ mod webfinger; use self::{ apub::PublicKey, - db_actor::Db, + db::Db, error::MyError, - label::ArbiterLabelFactory, state::{State, UrlKind}, verifier::MyVerify, webfinger::RelayResolver, @@ -95,25 +92,18 @@ async fn main() -> Result<(), anyhow::Error> { let use_whitelist = std::env::var("USE_WHITELIST").is_ok(); let use_https = std::env::var("USE_HTTPS").is_ok(); - let arbiter_labeler = ArbiterLabelFactory::new(); + let db = Db::build(pg_config.clone()).await?; - let db = Db::new(pg_config.clone()); - arbiter_labeler.clone().set_label(); - - let state: State = db - .execute_inline(move |pool| State::hydrate(use_https, use_whitelist, hostname, pool)) - .await??; + let state = State::hydrate(use_https, use_whitelist, hostname, &db).await?; let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone()); HttpServer::new(move || { - arbiter_labeler.clone().set_label(); let state = state.clone(); - let actor = Db::new(pg_config.clone()); App::new() .wrap(Logger::default()) - .data(actor) + .data(db.clone()) .data(state.clone()) .data(state.requests()) .service(web::resource("/").route(web::get().to(index))) diff --git a/src/state.rs b/src/state.rs index 5053785..1688cbf 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,9 +1,7 @@ -use crate::{apub::AcceptedActors, db_actor::Pool, requests::Requests}; +use crate::{apub::AcceptedActors, db::Db, error::MyError, requests::Requests}; use activitystreams::primitives::XsdAnyUri; -use anyhow::Error; -use bb8_postgres::tokio_postgres::Client; use futures::try_join; -use log::{error, info}; +use log::info; use lru::LruCache; use rand::thread_rng; use rsa::{RSAPrivateKey, RSAPublicKey}; @@ -44,28 +42,21 @@ pub enum UrlKind { Outbox, } -#[derive(Clone, Debug, thiserror::Error)] -#[error("Error generating RSA key")] -pub struct RsaError; - impl Settings { async fn hydrate( - client: &Client, + db: &Db, use_https: bool, whitelist_enabled: bool, hostname: String, - ) -> Result { - let private_key = if let Some(key) = crate::db::hydrate_private_key(client).await? { + ) -> Result { + let private_key = if let Some(key) = db.hydrate_private_key().await? { key } else { info!("Generating new keys"); let mut rng = thread_rng(); - let key = RSAPrivateKey::new(&mut rng, 4096).map_err(|e| { - error!("Error generating RSA key, {}", e); - RsaError - })?; + let key = RSAPrivateKey::new(&mut rng, 4096)?; - crate::db::update_private_key(client, &key).await?; + db.update_private_key(&key).await?; key }; @@ -249,35 +240,12 @@ impl State { use_https: bool, whitelist_enabled: bool, hostname: String, - pool: Pool, - ) -> Result { - let pool1 = pool.clone(); - let pool2 = pool.clone(); - let pool3 = pool.clone(); - - let f1 = async move { - let conn = pool.get().await?; - - crate::db::hydrate_blocks(&conn).await - }; - - let f2 = async move { - let conn = pool1.get().await?; - - crate::db::hydrate_whitelists(&conn).await - }; - - let f3 = async move { - let conn = pool2.get().await?; - - crate::db::hydrate_listeners(&conn).await - }; - - let f4 = async move { - let conn = pool3.get().await?; - - Settings::hydrate(&conn, use_https, whitelist_enabled, hostname).await - }; + db: &Db, + ) -> Result { + let f1 = db.hydrate_blocks(); + let f2 = db.hydrate_whitelists(); + let f3 = db.hydrate_listeners(); + let f4 = Settings::hydrate(db, use_https, whitelist_enabled, hostname); let (blocks, whitelists, listeners, settings) = try_join!(f1, f2, f3, f4)?;