From 37e680a058202188387d6d75885cf48ed26f6046 Mon Sep 17 00:00:00 2001 From: TommyLike Date: Sun, 27 Aug 2023 10:17:00 +0800 Subject: [PATCH 1/2] Upgrade sqlx into sea-orm for token&user&clusterkey database model --- Cargo.toml | 1 + src/infra/database/model/clusterkey/dto.rs | 52 ++++-------- .../database/model/clusterkey/repository.rs | 73 +++++++++-------- src/infra/database/model/token/dto.rs | 47 +++-------- src/infra/database/model/token/repository.rs | 81 ++++++++++--------- src/infra/database/model/user/dto.rs | 25 +++--- src/infra/database/model/user/repository.rs | 59 +++++++------- src/infra/database/pool.rs | 28 +++++++ src/infra/sign_backend/factory.rs | 6 +- src/infra/sign_backend/memory/backend.rs | 6 +- src/presentation/server/control_server.rs | 8 +- src/presentation/server/data_server.rs | 8 +- src/util/error.rs | 6 ++ 13 files changed, 203 insertions(+), 197 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2efd5be..9755f54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ regex = "1" csrf= "0.4.1" data-encoding= "2.4.0" enum-iterator= "1.4.1" +sea-orm = { version = "0.12.2", features = [ "sqlx-mysql", "runtime-tokio-rustls", "macros", "with-chrono"] } [build-dependencies] tonic-build = "0.8.4" diff --git a/src/infra/database/model/clusterkey/dto.rs b/src/infra/database/model/clusterkey/dto.rs index 659268c..d16183f 100644 --- a/src/infra/database/model/clusterkey/dto.rs +++ b/src/infra/database/model/clusterkey/dto.rs @@ -18,10 +18,13 @@ use crate::domain::clusterkey::entity::ClusterKey; use sqlx::types::chrono; -use sqlx::FromRow; +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; -#[derive(Debug, FromRow)] -pub(super) struct ClusterKeyDTO { +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] +#[sea_orm(table_name = "cluster_key")] +pub struct Model { + #[sea_orm(primary_key)] pub id: i32, pub data: Vec, pub algorithm: String, @@ -29,8 +32,13 @@ pub(super) struct ClusterKeyDTO { pub create_at: chrono::DateTime, } -impl From for ClusterKey { - fn from(dto: ClusterKeyDTO) -> Self { +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} + +impl From for ClusterKey { + fn from(dto: Model) -> Self { ClusterKey { id: dto.id, data: dto.data, @@ -41,44 +49,14 @@ impl From for ClusterKey { } } -impl From for ClusterKeyDTO { - fn from(cluster_key: ClusterKey) -> Self { - Self { - id: cluster_key.id, - data: cluster_key.data, - algorithm: cluster_key.algorithm, - identity: cluster_key.identity, - create_at: cluster_key.create_at, - } - } -} - #[cfg(test)] mod tests { use chrono::Utc; - use super::{ClusterKey,ClusterKeyDTO}; - - #[test] - fn test_cluster_key_dto_from_entity() { - let key = ClusterKey { - id: 1, - data: vec![1, 2, 3], - algorithm: "algo".to_string(), - identity: "id".to_string(), - create_at: Utc::now() - }; - let create_at = key.create_at.clone(); - let dto = ClusterKeyDTO::from(key); - assert_eq!(dto.id, 1); - assert_eq!(dto.data, vec![1, 2, 3]); - assert_eq!(dto.algorithm, "algo"); - assert_eq!(dto.identity, "id"); - assert_eq!(dto.create_at, create_at); - } + use super::{ClusterKey,Model}; #[test] fn test_cluster_key_entity_from_dto() { - let dto = ClusterKeyDTO { + let dto = Model { id: 1, data: vec![1, 2, 3], algorithm: "algo".to_string(), diff --git a/src/infra/database/model/clusterkey/repository.rs b/src/infra/database/model/clusterkey/repository.rs index e34357e..a18cb0b 100644 --- a/src/infra/database/model/clusterkey/repository.rs +++ b/src/infra/database/model/clusterkey/repository.rs @@ -14,23 +14,24 @@ * */ -use super::dto::ClusterKeyDTO; -use crate::infra::database::pool::DbPool; +use super::dto::Entity as ClusterKeyDTO; +use crate::infra::database::model::clusterkey; +use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, QueryOrder}; use crate::domain::clusterkey::entity::ClusterKey; use crate::domain::clusterkey::repository::Repository; -use crate::util::error::Result; +use crate::util::error::{Result, Error}; use async_trait::async_trait; -use std::boxed::Box; +use sea_orm::sea_query::OnConflict; #[derive(Clone)] pub struct ClusterKeyRepository { - db_pool: DbPool, + db_connection: DatabaseConnection, } impl ClusterKeyRepository { - pub fn new(db_pool: DbPool) -> Self { + pub fn new(db_connection: DatabaseConnection) -> Self { Self { - db_pool, + db_connection, } } } @@ -38,43 +39,47 @@ impl ClusterKeyRepository { #[async_trait] impl Repository for ClusterKeyRepository { async fn create(&self, cluster_key: ClusterKey) -> Result<()> { - let dto = ClusterKeyDTO::from(cluster_key); - let _ : Option = sqlx::query_as("INSERT IGNORE INTO cluster_key(data, algorithm, identity, create_at) VALUES (?, ?, ?, ?)") - .bind(&dto.data) - .bind(&dto.algorithm) - .bind(&dto.identity) - .bind(dto.create_at) - .fetch_optional(&self.db_pool) - .await?; + let cluster_key = clusterkey::dto::ActiveModel { + data: Set(cluster_key.data), + algorithm: Set(cluster_key.algorithm), + identity: Set(cluster_key.identity), + create_at: Set(cluster_key.create_at), + ..Default::default() + }; + //TODO: https://github.com/SeaQL/sea-orm/issues/1790 + ClusterKeyDTO::insert(cluster_key).on_conflict(OnConflict::new() + .update_column(clusterkey::dto::Column::Id).to_owned() + ).exec(&self.db_connection).await?; Ok(()) } async fn get_latest(&self, algorithm: &str) -> Result> { - let latest: Option = sqlx::query_as( - "SELECT * FROM cluster_key WHERE algorithm = ? ORDER BY id DESC LIMIT 1", - ) - .bind(algorithm) - .fetch_optional(&self.db_pool) - .await?; - match latest { - Some(l) => return Ok(Some(ClusterKey::from(l))), - None => Ok(None), + match ClusterKeyDTO::find().filter( + clusterkey::dto::Column::Algorithm.eq(algorithm) + ).order_by_desc(clusterkey::dto::Column::Id).one( + &self.db_connection).await? { + None => { + Ok(None) + } + Some(cluster_key) => { + Ok(Some(ClusterKey::from(cluster_key))) + } } } async fn get_by_id(&self, id: i32) -> Result { - let selected: ClusterKeyDTO = sqlx::query_as("SELECT * FROM cluster_key WHERE id = ?") - .bind(id) - .fetch_one(&self.db_pool) - .await?; - Ok(ClusterKey::from(selected)) + match ClusterKeyDTO::find_by_id(id).one(&self.db_connection).await? { + None => { + Err(Error::NotFoundError) + } + Some(cluster_key) => { + Ok(ClusterKey::from(cluster_key)) + } + } } - async fn delete_by_id(&self, id: i32) -> Result<()> { - let _: Option = sqlx::query_as("DELETE FROM cluster_key where id = ?") - .bind(id) - .fetch_optional(&self.db_pool) - .await?; + let _ = ClusterKeyDTO::delete_by_id( + id).exec(&self.db_connection).await?; Ok(()) } } diff --git a/src/infra/database/model/token/dto.rs b/src/infra/database/model/token/dto.rs index 12084a6..0b3f5b0 100644 --- a/src/infra/database/model/token/dto.rs +++ b/src/infra/database/model/token/dto.rs @@ -13,14 +13,16 @@ * * // See the Mulan PSL v2 for more details. * */ -use sqlx::FromRow; use chrono::{DateTime, Utc}; use crate::domain::token::entity::Token; -use crate::util::key::get_token_hash; +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; -#[derive(Debug, FromRow, Clone)] -pub(super) struct TokenDTO { +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] +#[sea_orm(table_name = "token")] +pub struct Model { + #[sea_orm(primary_key)] pub id: i32, pub user_id: i32, pub description: String, @@ -29,21 +31,13 @@ pub(super) struct TokenDTO { pub expire_at: DateTime, } -impl From for TokenDTO { - fn from(token: Token) -> Self { - Self { - id: token.id, - user_id: token.user_id, - description: token.description.clone(), - token: get_token_hash(&token.token), - create_at: token.create_at, - expire_at: token.expire_at, - } - } -} +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} -impl From for Token { - fn from(dto: TokenDTO) -> Self { +impl From for Token { + fn from(dto: Model) -> Self { Self { id: dto.id, user_id: dto.user_id, @@ -59,25 +53,10 @@ impl From for Token { mod tests { use super::*; use chrono::Utc; - - #[test] - fn test_token_dto_from_entity() { - let token = Token::new(1, "Test token".to_string(), "abc123".to_string()).unwrap(); - let token_hash = get_token_hash(&token.token); - let dto = TokenDTO::from(token.clone()); - assert_eq!(dto.id, token.id); - assert_eq!(dto.user_id, token.user_id); - assert_eq!(dto.description, token.description); - assert_ne!(dto.token, token.token); - assert_eq!(dto.token, token_hash); - assert_eq!(dto.create_at, token.create_at); - assert_eq!(dto.expire_at, token.expire_at); - } - #[test] fn test_token_entity_from_dto() { let now = Utc::now(); - let dto = TokenDTO { + let dto = Model { id: 1, user_id: 2, description: "Test token".to_string(), diff --git a/src/infra/database/model/token/repository.rs b/src/infra/database/model/token/repository.rs index 4677be7..2bba298 100644 --- a/src/infra/database/model/token/repository.rs +++ b/src/infra/database/model/token/repository.rs @@ -14,80 +14,85 @@ * */ -use crate::infra::database::pool::DbPool; use crate::domain::token::entity::{Token}; use crate::domain::token::repository::Repository; use crate::util::error::Result; use async_trait::async_trait; -use std::boxed::Box; - -use crate::infra::database::model::token::dto::TokenDTO; +use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, Condition, ActiveModelTrait}; +use crate::infra::database::model::token; +use crate::infra::database::model::token::dto::Entity as TokenDTO; +use crate::util::error; use crate::util::key::get_token_hash; #[derive(Clone)] pub struct TokenRepository { - db_pool: DbPool, + db_connection: DatabaseConnection } impl TokenRepository { - pub fn new(db_pool: DbPool) -> Self { + pub fn new(db_connection: DatabaseConnection) -> Self { Self { - db_pool, + db_connection, } } } #[async_trait] impl Repository for TokenRepository { - async fn create(&self, token: Token) -> Result { - let dto = TokenDTO::from(token); - let record : u64 = sqlx::query("INSERT INTO token(user_id, description, token, create_at, expire_at) VALUES (?, ?, ?, ?, ?)") - .bind(dto.user_id) - .bind(&dto.description) - .bind(&dto.token) - .bind(dto.create_at) - .bind(dto.expire_at) - .execute(&self.db_pool) - .await?.last_insert_id(); - self.get_token_by_id(record as i32).await + let token = token::dto::ActiveModel { + user_id: Set(token.user_id), + description: Set(token.description), + token: Set(get_token_hash(&token.token)), + create_at:Set(token.create_at), + expire_at: Set(token.expire_at), + ..Default::default() + }; + Ok(Token::from(token.insert(&self.db_connection).await?)) } async fn get_token_by_id(&self, id: i32) -> Result { - let selected: TokenDTO = sqlx::query_as("SELECT * FROM token WHERE id = ?") - .bind(id) - .fetch_one(&self.db_pool) - .await?; - Ok(Token::from(selected)) + match TokenDTO::find_by_id(id).one(&self.db_connection).await? { + None => { + Err(error::Error::NotFoundError) + } + Some(token) => { + Ok(Token::from(token)) + } + } } async fn get_token_by_value(&self, token: &str) -> Result { - let selected: TokenDTO = sqlx::query_as("SELECT * FROM token WHERE token = ?") - .bind(get_token_hash(token)) - .fetch_one(&self.db_pool) - .await?; - Ok(Token::from(selected)) + match TokenDTO::find().filter( + token::dto::Column::Token.eq(get_token_hash(token))).one( + &self.db_connection).await? { + None => { + Err(error::Error::NotFoundError) + } + Some(token) => { + Ok(Token::from(token)) + } + } + + } async fn delete_by_user_and_id(&self, id: i32, user_id: i32) -> Result<()> { - let _: Option = sqlx::query_as("DELETE FROM token where id = ? AND user_id = ?") - .bind(id) - .bind(user_id) - .fetch_optional(&self.db_pool) + let _ = TokenDTO::delete_many().filter(Condition::all() + .add(token::dto::Column::Id.eq(id)) + .add(token::dto::Column::UserId.eq(user_id))).exec(&self.db_connection) .await?; Ok(()) } async fn get_token_by_user_id(&self, id: i32) -> Result> { - let dtos: Vec = sqlx::query_as("SELECT * FROM token WHERE user_id = ?") - .bind(id) - .fetch_all(&self.db_pool) - .await?; + let tokens = TokenDTO::find().filter( + token::dto::Column::UserId.eq(id)).all(&self.db_connection).await?; let mut results = vec![]; - for dto in dtos.into_iter() { + for dto in tokens.into_iter() { results.push(Token::from(dto)); } Ok(results) } -} +} \ No newline at end of file diff --git a/src/infra/database/model/user/dto.rs b/src/infra/database/model/user/dto.rs index c9011aa..a2489a9 100644 --- a/src/infra/database/model/user/dto.rs +++ b/src/infra/database/model/user/dto.rs @@ -13,26 +13,25 @@ * * // See the Mulan PSL v2 for more details. * */ -use sqlx::FromRow; use crate::domain::user::entity::User; +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; -#[derive(Debug, FromRow)] -pub(super) struct UserDTO { +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] +#[sea_orm(table_name = "user")] +pub struct Model { + #[sea_orm(primary_key)] pub id: i32, pub email: String } -impl From for UserDTO { - fn from(user: User) -> Self { - Self { - id: user.id, - email: user.email, - } - } -} +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} -impl From for User { - fn from(dto: UserDTO) -> Self { +impl From for User { + fn from(dto: Model) -> Self { Self { id: dto.id, email: dto.email diff --git a/src/infra/database/model/user/repository.rs b/src/infra/database/model/user/repository.rs index 4f1775e..d476dd9 100644 --- a/src/infra/database/model/user/repository.rs +++ b/src/infra/database/model/user/repository.rs @@ -14,24 +14,23 @@ * */ -use super::dto::UserDTO; - -use crate::infra::database::pool::DbPool; +use super::dto::Entity as UserDTO; +use crate::infra::database::model::user; +use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, ActiveModelTrait}; use crate::domain::user::entity::User; use crate::domain::user::repository::Repository; -use crate::util::error::Result; +use crate::util::error::{Error, Result}; use async_trait::async_trait; -use std::boxed::Box; #[derive(Clone)] pub struct UserRepository { - db_pool: DbPool, + db_connection: DatabaseConnection } impl UserRepository { - pub fn new(db_pool: DbPool) -> Self { + pub fn new(db_connection: DatabaseConnection) -> Self { Self { - db_pool, + db_connection, } } } @@ -45,36 +44,42 @@ impl Repository for UserRepository { Ok(existed) } Err(_err) => { - let dto = UserDTO::from(user); - let record : u64 = sqlx::query("INSERT INTO user(email) VALUES (?)") - .bind(&dto.email) - .execute(&self.db_pool) - .await?.last_insert_id(); - self.get_by_id(record as i32).await + let user = user::dto::ActiveModel { + email: Set(user.email), + ..Default::default() + }; + Ok(User::from(user.insert(&self.db_connection).await?)) } } } async fn get_by_id(&self, id: i32) -> Result { - let selected: UserDTO = sqlx::query_as("SELECT * FROM user WHERE id = ?") - .bind(id) - .fetch_one(&self.db_pool) - .await?; - Ok(User::from(selected)) + match UserDTO::find_by_id(id).one( + &self.db_connection).await? { + None => { + Err(Error::NotFoundError) + } + Some(user) => { + Ok(User::from(user)) + } + } } async fn get_by_email(&self, email: &str) -> Result { - let selected: UserDTO = sqlx::query_as("SELECT * FROM user WHERE email = ?") - .bind(email) - .fetch_one(&self.db_pool) - .await?; - Ok(User::from(selected)) + match UserDTO::find().filter( + user::dto::Column::Email.eq(email)).one( + &self.db_connection).await? { + None => { + Err(Error::NotFoundError) + } + Some(user) => { + Ok(User::from(user)) + } + } } async fn delete_by_id(&self, id: i32) -> Result<()> { - let _: Option = sqlx::query_as("DELETE FROM user where id = ?") - .bind(id) - .fetch_optional(&self.db_pool) + let _ = UserDTO::delete_by_id(id).exec(&self.db_connection) .await?; Ok(()) } diff --git a/src/infra/database/pool.rs b/src/infra/database/pool.rs index d36503f..bbddf63 100644 --- a/src/infra/database/pool.rs +++ b/src/infra/database/pool.rs @@ -20,11 +20,16 @@ use once_cell::sync::OnceCell; use sqlx::mysql::{MySql, MySqlPoolOptions}; use sqlx::pool::Pool; use std::collections::HashMap; +use std::time::Duration; +use sea_orm::{DatabaseConnection, ConnectOptions, Database}; use crate::util::error::{Error, Result}; pub type DbPool = Pool; +//Now we have database pool for sqlx framework and database connection for sea-orm framework, +//db_pool will be removed when all database operations has been upgraded into sea-orm static DB_POOL: OnceCell = OnceCell::new(); +static DB_CONNECTION: OnceCell = OnceCell::new(); pub async fn create_pool(config: &HashMap) -> Result<()> { let max_connections: u32 = config @@ -54,6 +59,20 @@ pub async fn create_pool(config: &HashMap) -> Result<()> { .await .map_err(Error::from)?; DB_POOL.set(pool).expect("db pool configured"); + + //initialize the database connection + let mut opt = ConnectOptions::new(db_connection); + opt.max_connections(max_connections) + .min_connections(5) + .connect_timeout(Duration::from_secs(8)) + .acquire_timeout(Duration::from_secs(8)) + .idle_timeout(Duration::from_secs(8)) + .max_lifetime(Duration::from_secs(8)) + .sqlx_logging(true) + .sqlx_logging_level(log::LevelFilter::Info); + + DB_CONNECTION.set(Database::connect(opt).await?).expect("database connection configured"); + get_db_connection()?.ping().await.expect("database connection failed"); ping().await?; Ok(()) } @@ -67,6 +86,15 @@ pub fn get_db_pool() -> Result { }; } +pub fn get_db_connection() -> Result { + return match DB_CONNECTION.get() { + None => Err(Error::DatabaseError( + "failed to get database pool".to_string(), + )), + Some(pool) => Ok(pool.clone()), + }; +} + pub async fn ping() -> Result<()> { info!("Checking on database connection..."); let pool = get_db_pool(); diff --git a/src/infra/sign_backend/factory.rs b/src/infra/sign_backend/factory.rs index 98b04ba..a422d2a 100644 --- a/src/infra/sign_backend/factory.rs +++ b/src/infra/sign_backend/factory.rs @@ -19,19 +19,19 @@ use crate::domain::sign_service::{SignBackend, SignBackendType}; use crate::util::error::{Result}; use std::sync::{Arc, RwLock}; use config::{Config}; -use crate::infra::database::pool::DbPool; +use sea_orm::DatabaseConnection; use crate::infra::sign_backend::memory::backend::MemorySignBackend; pub struct SignBackendFactory {} impl SignBackendFactory { - pub async fn new_engine(config: Arc>, db_pool: DbPool) -> Result> { + pub async fn new_engine(config: Arc>, db_connection: DatabaseConnection) -> Result> { let engine_type = SignBackendType::from_str( config.read()?.get_string("sign-backend.type")?.as_str(), )?; info!("sign backend configured with plugin {:?}", engine_type); match engine_type { - SignBackendType::Memory => Ok(Box::new(MemorySignBackend::new(config, db_pool).await?)), + SignBackendType::Memory => Ok(Box::new(MemorySignBackend::new(config, db_connection).await?)), } } } diff --git a/src/infra/sign_backend/memory/backend.rs b/src/infra/sign_backend/memory/backend.rs index a1b51b4..5322465 100644 --- a/src/infra/sign_backend/memory/backend.rs +++ b/src/infra/sign_backend/memory/backend.rs @@ -24,7 +24,6 @@ use config::Config; use std::sync::RwLock; use crate::infra::database::model::clusterkey::repository; -use crate::infra::database::pool::{DbPool}; use crate::infra::kms::factory; use crate::infra::encryption::engine::{EncryptionEngineWithClusterKey}; use crate::domain::encryption_engine::EncryptionEngine; @@ -34,6 +33,7 @@ use crate::domain::datakey::entity::DataKey; use crate::util::error::{Error, Result}; use async_trait::async_trait; use chrono::{DateTime, Utc}; +use sea_orm::DatabaseConnection; use crate::infra::encryption::algorithm::factory::AlgorithmFactory; @@ -50,13 +50,13 @@ impl MemorySignBackend { /// 2. initialize the cluster repo /// 2. initialize the encryption engine including the cluster key /// 3. initialize the signing plugins - pub async fn new(server_config: Arc>, db_pool: DbPool) -> Result { + pub async fn new(server_config: Arc>, db_connection: DatabaseConnection) -> Result { //initialize the kms backend let kms_provider = factory::KMSProviderFactory::new_provider( &server_config.read()?.get_table("memory.kms-provider")? )?; let repository = - repository::ClusterKeyRepository::new(db_pool); + repository::ClusterKeyRepository::new(db_connection); let engine_config = server_config.read()?.get_table("memory.encryption-engine")?; let encryptor = AlgorithmFactory::new_algorithm( &engine_config diff --git a/src/presentation/server/control_server.rs b/src/presentation/server/control_server.rs index 3ad0b50..439c2e2 100644 --- a/src/presentation/server/control_server.rs +++ b/src/presentation/server/control_server.rs @@ -31,7 +31,7 @@ use time::Duration as timeDuration; use std::time::Duration; use crate::infra::database::model::datakey::repository as datakeyRepository; -use crate::infra::database::pool::{create_pool, get_db_pool}; +use crate::infra::database::pool::{create_pool, get_db_connection, get_db_pool}; use crate::presentation::handler::control::*; use actix_web::{dev::ServiceRequest}; use actix_web::cookie::SameSite; @@ -123,10 +123,10 @@ impl ControlServer { get_db_pool()?, ); let sign_backend = SignBackendFactory::new_engine( - server_config.clone(), get_db_pool()?).await?; + server_config.clone(), get_db_connection()?).await?; //initialize repos - let user_repo = UserRepository::new(get_db_pool()?); - let token_repo = TokenRepository::new(get_db_pool()?); + let user_repo = UserRepository::new(get_db_connection()?); + let token_repo = TokenRepository::new(get_db_connection()?); //initialize the service let user_service = Arc::new( diff --git a/src/presentation/server/data_server.rs b/src/presentation/server/data_server.rs index 2727ebe..e073de0 100644 --- a/src/presentation/server/data_server.rs +++ b/src/presentation/server/data_server.rs @@ -31,7 +31,7 @@ use crate::application::user::{DBUserService, UserService}; use crate::infra::database::model::datakey::repository; use crate::infra::database::model::token::repository::TokenRepository; use crate::infra::database::model::user::repository::UserRepository; -use crate::infra::database::pool::{create_pool, get_db_pool}; +use crate::infra::database::pool::{create_pool, get_db_connection, get_db_pool}; use crate::infra::sign_backend::factory::SignBackendFactory; @@ -107,12 +107,12 @@ impl DataServer { let mut server = Server::builder(); info!("data server starts"); let sign_backend = SignBackendFactory::new_engine( - self.server_config.clone(), get_db_pool()?).await?; + self.server_config.clone(), get_db_connection()?).await?; let data_repository = repository::DataKeyRepository::new( get_db_pool()?); let key_service = DBKeyService::new(data_repository, sign_backend); - let user_repo = UserRepository::new(get_db_pool()?); - let token_repo = TokenRepository::new(get_db_pool()?); + let user_repo = UserRepository::new(get_db_connection()?); + let token_repo = TokenRepository::new(get_db_connection()?); let user_service = DBUserService::new(user_repo, token_repo, self.server_config.clone())?; key_service.start_cache_cleanup_loop(self.cancel_token.clone())?; diff --git a/src/util/error.rs b/src/util/error.rs index 3e4492c..448b2f0 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -45,6 +45,7 @@ use anyhow::Error as AnyhowError; use csrf::CsrfError; use utoipa::{ToSchema}; use efi_signer::error::Error as EFIError; +use sea_orm::DbErr; pub type Result = std::result::Result; @@ -402,4 +403,9 @@ impl From> for Error { fn from(error: Vec) -> Self { Error::KeyParseError(format!("original vec {:?}", error)) } } +impl From for Error { + fn from(error: DbErr) -> Self { Error::DatabaseError(error.to_string()) } +} + + -- Gitee From b1bdae7ab687f08bedce7f1f6d8f32ba1436f234 Mon Sep 17 00:00:00 2001 From: TommyLike Date: Thu, 31 Aug 2023 19:47:23 +0800 Subject: [PATCH 2/2] Add testcase for repository --- Cargo.toml | 1 + src/domain/clusterkey/entity.rs | 3 +- src/domain/token/entity.rs | 2 +- src/domain/user/entity.rs | 2 +- .../database/model/clusterkey/repository.rs | 172 +++++++++++- src/infra/database/model/token/repository.rs | 244 +++++++++++++++++- src/infra/database/model/user/repository.rs | 164 +++++++++++- src/infra/database/pool.rs | 4 +- src/infra/sign_backend/factory.rs | 2 +- src/infra/sign_backend/memory/backend.rs | 2 +- 10 files changed, 557 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9755f54..c9c198f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,6 +84,7 @@ tonic-build = "0.8.4" [dev-dependencies] mockito = "1.0.2" +sea-orm = { version = "0.12.2", features = [ "mock"] } [[bin]] name = "client" diff --git a/src/domain/clusterkey/entity.rs b/src/domain/clusterkey/entity.rs index f038edf..564735b 100644 --- a/src/domain/clusterkey/entity.rs +++ b/src/domain/clusterkey/entity.rs @@ -22,7 +22,7 @@ use std::vec::Vec; use crate::domain::kms_provider::KMSProvider; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct ClusterKey { pub id: i32, pub data: Vec, @@ -89,7 +89,6 @@ impl Default for SecClusterKey { } } - impl SecClusterKey { pub async fn load(cluster_key: ClusterKey, kms_provider: &Box) -> Result where K: KMSProvider + ?Sized { diff --git a/src/domain/token/entity.rs b/src/domain/token/entity.rs index d6087bc..6b65e2c 100644 --- a/src/domain/token/entity.rs +++ b/src/domain/token/entity.rs @@ -21,7 +21,7 @@ use std::fmt::{Display, Formatter}; const TOKEN_EXPIRE_IN_DAYS: i64 = 180; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Token { pub id: i32, pub user_id: i32, diff --git a/src/domain/user/entity.rs b/src/domain/user/entity.rs index 0541e37..ad72d75 100644 --- a/src/domain/user/entity.rs +++ b/src/domain/user/entity.rs @@ -19,7 +19,7 @@ use std::fmt::{Display, Formatter}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct User { pub id: i32, pub email: String diff --git a/src/infra/database/model/clusterkey/repository.rs b/src/infra/database/model/clusterkey/repository.rs index a18cb0b..edf8334 100644 --- a/src/infra/database/model/clusterkey/repository.rs +++ b/src/infra/database/model/clusterkey/repository.rs @@ -24,12 +24,12 @@ use async_trait::async_trait; use sea_orm::sea_query::OnConflict; #[derive(Clone)] -pub struct ClusterKeyRepository { - db_connection: DatabaseConnection, +pub struct ClusterKeyRepository<'a> { + db_connection: &'a DatabaseConnection, } -impl ClusterKeyRepository { - pub fn new(db_connection: DatabaseConnection) -> Self { +impl<'a> ClusterKeyRepository<'a> { + pub fn new(db_connection: &'a DatabaseConnection) -> Self { Self { db_connection, } @@ -37,7 +37,7 @@ impl ClusterKeyRepository { } #[async_trait] -impl Repository for ClusterKeyRepository { +impl<'a> Repository for ClusterKeyRepository<'a> { async fn create(&self, cluster_key: ClusterKey) -> Result<()> { let cluster_key = clusterkey::dto::ActiveModel { data: Set(cluster_key.data), @@ -49,7 +49,7 @@ impl Repository for ClusterKeyRepository { //TODO: https://github.com/SeaQL/sea-orm/issues/1790 ClusterKeyDTO::insert(cluster_key).on_conflict(OnConflict::new() .update_column(clusterkey::dto::Column::Id).to_owned() - ).exec(&self.db_connection).await?; + ).exec(self.db_connection).await?; Ok(()) } @@ -57,7 +57,7 @@ impl Repository for ClusterKeyRepository { match ClusterKeyDTO::find().filter( clusterkey::dto::Column::Algorithm.eq(algorithm) ).order_by_desc(clusterkey::dto::Column::Id).one( - &self.db_connection).await? { + self.db_connection).await? { None => { Ok(None) } @@ -68,7 +68,7 @@ impl Repository for ClusterKeyRepository { } async fn get_by_id(&self, id: i32) -> Result { - match ClusterKeyDTO::find_by_id(id).one(&self.db_connection).await? { + match ClusterKeyDTO::find_by_id(id).one(self.db_connection).await? { None => { Err(Error::NotFoundError) } @@ -79,7 +79,161 @@ impl Repository for ClusterKeyRepository { } async fn delete_by_id(&self, id: i32) -> Result<()> { let _ = ClusterKeyDTO::delete_by_id( - id).exec(&self.db_connection).await?; + id).exec(self.db_connection).await?; Ok(()) } } + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + use crate::domain::clusterkey::entity::ClusterKey; + use crate::domain::clusterkey::repository::Repository; + use crate::infra::database::model::clusterkey::dto; + use crate::util::error::Result; + use crate::infra::database::model::clusterkey::repository::{ClusterKeyRepository}; + + #[tokio::test] + async fn test_cluster_key_repository_create_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 0, + data: vec![], + algorithm: "".to_string(), + identity: "".to_string(), + create_at: now.clone(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 1, + rows_affected: 1, + } + ]).into_connection(); + + let key_repository = ClusterKeyRepository::new(&db); + let key = ClusterKey{ + id: 0, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }; + assert_eq!(key_repository.create(key).await?, ()); + assert_eq!( + db.into_transaction_log(), + [ + //create + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"INSERT INTO `cluster_key` (`data`, `algorithm`, `identity`, `create_at`) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE `id` = VALUES(`id`)"#, + [vec![].into(), "fake_algorithm".into(), "123".into(), now.clone().into()] + ), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_cluster_key_repository_delete_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 1, + rows_affected: 1, + } + ]).into_connection(); + + let key_repository = ClusterKeyRepository::new(&db); + assert_eq!(key_repository.delete_by_id(1).await?, ()); + assert_eq!( + db.into_transaction_log(), + [ + //delete + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"DELETE FROM `cluster_key` WHERE `cluster_key`.`id` = ?"#, + [1i32.into()] + ), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_cluster_key_repository_query_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }], + vec![dto::Model { + id: 2, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }], + ], + ).into_connection(); + + let key_repository = ClusterKeyRepository::new(&db); + assert_eq!( + key_repository.get_latest("fake_algorithm").await?, + Some(ClusterKey::from(dto::Model { + id: 1, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + })) + ); + assert_eq!( + key_repository.get_by_id(123).await?, + ClusterKey::from(dto::Model { + id: 2, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }) + ); + assert_eq!( + db.into_transaction_log(), + [ + //get_latest + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `cluster_key`.`id`, `cluster_key`.`data`, `cluster_key`.`algorithm`, `cluster_key`.`identity`, `cluster_key`.`create_at` FROM `cluster_key` WHERE `cluster_key`.`algorithm` = ? ORDER BY `cluster_key`.`id` DESC LIMIT ?"#, + ["fake_algorithm".into(), 1u64.into()] + ), + //get_by_id + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `cluster_key`.`id`, `cluster_key`.`data`, `cluster_key`.`algorithm`, `cluster_key`.`identity`, `cluster_key`.`create_at` FROM `cluster_key` WHERE `cluster_key`.`id` = ? LIMIT ?"#, + [123i32.into(), 1u64.into()] + ), + ] + ); + + Ok(()) + } +} + diff --git a/src/infra/database/model/token/repository.rs b/src/infra/database/model/token/repository.rs index 2bba298..aaa5efc 100644 --- a/src/infra/database/model/token/repository.rs +++ b/src/infra/database/model/token/repository.rs @@ -26,12 +26,12 @@ use crate::util::key::get_token_hash; #[derive(Clone)] -pub struct TokenRepository { - db_connection: DatabaseConnection +pub struct TokenRepository<'a> { + db_connection: &'a DatabaseConnection } -impl TokenRepository { - pub fn new(db_connection: DatabaseConnection) -> Self { +impl<'a> TokenRepository<'a> { + pub fn new(db_connection: &'a DatabaseConnection) -> Self { Self { db_connection, } @@ -39,7 +39,7 @@ impl TokenRepository { } #[async_trait] -impl Repository for TokenRepository { +impl<'a> Repository for TokenRepository<'a> { async fn create(&self, token: Token) -> Result { let token = token::dto::ActiveModel { user_id: Set(token.user_id), @@ -49,11 +49,11 @@ impl Repository for TokenRepository { expire_at: Set(token.expire_at), ..Default::default() }; - Ok(Token::from(token.insert(&self.db_connection).await?)) + Ok(Token::from(token.insert(self.db_connection).await?)) } async fn get_token_by_id(&self, id: i32) -> Result { - match TokenDTO::find_by_id(id).one(&self.db_connection).await? { + match TokenDTO::find_by_id(id).one(self.db_connection).await? { None => { Err(error::Error::NotFoundError) } @@ -66,7 +66,7 @@ impl Repository for TokenRepository { async fn get_token_by_value(&self, token: &str) -> Result { match TokenDTO::find().filter( token::dto::Column::Token.eq(get_token_hash(token))).one( - &self.db_connection).await? { + self.db_connection).await? { None => { Err(error::Error::NotFoundError) } @@ -74,25 +74,243 @@ impl Repository for TokenRepository { Ok(Token::from(token)) } } - - } async fn delete_by_user_and_id(&self, id: i32, user_id: i32) -> Result<()> { let _ = TokenDTO::delete_many().filter(Condition::all() .add(token::dto::Column::Id.eq(id)) - .add(token::dto::Column::UserId.eq(user_id))).exec(&self.db_connection) + .add(token::dto::Column::UserId.eq(user_id))).exec(self.db_connection) .await?; Ok(()) } async fn get_token_by_user_id(&self, id: i32) -> Result> { let tokens = TokenDTO::find().filter( - token::dto::Column::UserId.eq(id)).all(&self.db_connection).await?; + token::dto::Column::UserId.eq(id)).all(self.db_connection).await?; let mut results = vec![]; for dto in tokens.into_iter() { results.push(Token::from(dto)); } Ok(results) } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + use crate::domain::token::entity::Token; + use crate::domain::token::repository::Repository; + use crate::infra::database::model::token::dto; + use crate::util::error::Result; + use crate::infra::database::model::token::repository::{TokenRepository}; + use crate::util::key::get_token_hash; + + #[tokio::test] + async fn test_token_repository_create_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 1, + rows_affected: 1, + } + ]).into_connection(); + + let token_repository = TokenRepository::new(&db); + let user = Token{ + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }; + assert_eq!( + token_repository.create(user).await?, + Token::from(dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }) + ); + let hashed_token = get_token_hash("random_number"); + assert_eq!( + db.into_transaction_log(), + [ + //create + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"INSERT INTO `token` (`user_id`, `description`, `token`, `create_at`, `expire_at`) VALUES (?, ?, ?, ?, ?)"#, + [0i32.into(), "fake_token".into(), hashed_token.into(), now.clone().into(), now.clone().into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `token`.`id`, `token`.`user_id`, `token`.`description`, `token`.`token`, `token`.`create_at`, `token`.`expire_at` FROM `token` WHERE `token`.`id` = ? LIMIT ?"#, + [1i32.into(), 1u64.into()] + ), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_token_repository_delete_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 1, + rows_affected: 1, + } + ]).into_connection(); + + let token_repository = TokenRepository::new(&db); + assert_eq!(token_repository.delete_by_user_and_id(1, 1).await?, ()); + assert_eq!( + db.into_transaction_log(), + [ + //delete + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"DELETE FROM `token` WHERE `token`.`id` = ? AND `token`.`user_id` = ?"#, + [1i32.into(), 1i32.into()] + ), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_token_repository_query_sql_statement() -> Result<()> { + let now = chrono::Utc::now(); + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }], + vec![dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }], + vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }, dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }], + ]).into_connection(); + + let token_repository = TokenRepository::new(&db); + assert_eq!( + token_repository.get_token_by_id(1).await?, + Token::from(dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }) + ); + assert_eq!( + token_repository.get_token_by_value("fake_content").await?, + Token::from(dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }) + ); + + assert_eq!( + token_repository.get_token_by_user_id(0).await?, + vec![ + Token::from(dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }), + Token::from(dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + })] + ); + + let hashed_token = get_token_hash("fake_content"); + assert_eq!( + db.into_transaction_log(), + [ + //get_token_by_id + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `token`.`id`, `token`.`user_id`, `token`.`description`, `token`.`token`, `token`.`create_at`, `token`.`expire_at` FROM `token` WHERE `token`.`id` = ? LIMIT ?"#, + [1i32.into(), 1u64.into()] + ), + //get_token_by_value + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `token`.`id`, `token`.`user_id`, `token`.`description`, `token`.`token`, `token`.`create_at`, `token`.`expire_at` FROM `token` WHERE `token`.`token` = ? LIMIT ?"#, + [hashed_token.into(), 1u64.into()] + ), + //get_token_by_user_id + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `token`.`id`, `token`.`user_id`, `token`.`description`, `token`.`token`, `token`.`create_at`, `token`.`expire_at` FROM `token` WHERE `token`.`user_id` = ?"#, + [0i32.into()] + ), + ] + ); + + Ok(()) + } +} diff --git a/src/infra/database/model/user/repository.rs b/src/infra/database/model/user/repository.rs index d476dd9..f6ecaf8 100644 --- a/src/infra/database/model/user/repository.rs +++ b/src/infra/database/model/user/repository.rs @@ -23,12 +23,12 @@ use crate::util::error::{Error, Result}; use async_trait::async_trait; #[derive(Clone)] -pub struct UserRepository { - db_connection: DatabaseConnection +pub struct UserRepository<'a> { + db_connection: &'a DatabaseConnection } -impl UserRepository { - pub fn new(db_connection: DatabaseConnection) -> Self { +impl<'a> UserRepository<'a> { + pub fn new(db_connection: &'a DatabaseConnection) -> Self { Self { db_connection, } @@ -36,7 +36,7 @@ impl UserRepository { } #[async_trait] -impl Repository for UserRepository { +impl<'a> Repository for UserRepository<'a> { async fn create(&self, user: User) -> Result { return match self.get_by_email(&user.email).await { @@ -48,14 +48,14 @@ impl Repository for UserRepository { email: Set(user.email), ..Default::default() }; - Ok(User::from(user.insert(&self.db_connection).await?)) + Ok(User::from(user.insert(self.db_connection).await?)) } } } async fn get_by_id(&self, id: i32) -> Result { match UserDTO::find_by_id(id).one( - &self.db_connection).await? { + self.db_connection).await? { None => { Err(Error::NotFoundError) } @@ -68,7 +68,7 @@ impl Repository for UserRepository { async fn get_by_email(&self, email: &str) -> Result { match UserDTO::find().filter( user::dto::Column::Email.eq(email)).one( - &self.db_connection).await? { + self.db_connection).await? { None => { Err(Error::NotFoundError) } @@ -79,8 +79,154 @@ impl Repository for UserRepository { } async fn delete_by_id(&self, id: i32) -> Result<()> { - let _ = UserDTO::delete_by_id(id).exec(&self.db_connection) + let _ = UserDTO::delete_by_id(id).exec(self.db_connection) .await?; Ok(()) } } + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + use crate::domain::user::entity::User; + use crate::domain::user::repository::Repository; + use crate::infra::database::model::user::dto; + use crate::util::error::Result; + use crate::infra::database::model::user::repository::UserRepository; + + #[tokio::test] + async fn test_user_repository_query_sql_statement() -> Result<()> { + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + email: "fake_email".to_string(), + }], + vec![dto::Model { + id: 2, + email: "fake_email".to_string(), + }], + ]).into_connection(); + + let user_repository = UserRepository::new(&db); + assert_eq!( + user_repository.get_by_email("fake_email").await?, + User::from(dto::Model { + id: 1, + email: "fake_email".to_string(), + }) + ); + + assert_eq!( + user_repository.get_by_id(1).await?, + User::from(dto::Model { + id: 2, + email: "fake_email".to_string(), + }) + ); + + assert_eq!( + db.into_transaction_log(), + [ + //get_by_email + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `user`.`id`, `user`.`email` FROM `user` WHERE `user`.`email` = ? LIMIT ?"#, + ["fake_email".into(), 1u64.into()] + ), + //get_by_id + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `user`.`id`, `user`.`email` FROM `user` WHERE `user`.`id` = ? LIMIT ?"#, + [1i32.into(), 1u64.into()] + ), + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_user_repository_create_sql_statement() -> Result<()> { + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![], + vec![dto::Model { + id: 3, + email: "fake_email".to_string(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 3, + rows_affected: 1, + } + ]).into_connection(); + + let user_repository = UserRepository::new(&db); + let user = User{ + id: 0, + email: "fake_string".to_string(), + }; + assert_eq!( + user_repository.create(user).await?, + User::from(dto::Model { + id: 3, + email: "fake_email".to_string(), + }) + ); + assert_eq!( + db.into_transaction_log(), + [ + //create + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `user`.`id`, `user`.`email` FROM `user` WHERE `user`.`email` = ? LIMIT ?"#, + ["fake_string".into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"INSERT INTO `user` (`email`) VALUES (?)"#, + ["fake_string".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `user`.`id`, `user`.`email` FROM `user` WHERE `user`.`id` = ? LIMIT ?"#, + [3i32.into(), 1u64.into()] + ), + ] + ); + + Ok(()) + } + #[tokio::test] + async fn test_user_repository_delete_sql_statement() -> Result<()> { + let db = MockDatabase::new(DatabaseBackend::MySql) + .append_query_results([ + vec![dto::Model { + id: 1, + email: "fake_email".to_string(), + }], + ]).append_exec_results([ + MockExecResult{ + last_insert_id: 1, + rows_affected: 1, + } + ]).into_connection(); + + let user_repository = UserRepository::new(&db); + assert_eq!(user_repository.delete_by_id(1).await?, ()); + assert_eq!( + db.into_transaction_log(), + [ + //delete + Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"DELETE FROM `user` WHERE `user`.`id` = ?"#, + [1i32.into()] + ), + ] + ); + + Ok(()) + } +} diff --git a/src/infra/database/pool.rs b/src/infra/database/pool.rs index bbddf63..51b0344 100644 --- a/src/infra/database/pool.rs +++ b/src/infra/database/pool.rs @@ -86,12 +86,12 @@ pub fn get_db_pool() -> Result { }; } -pub fn get_db_connection() -> Result { +pub fn get_db_connection() -> Result<&'static DatabaseConnection> { return match DB_CONNECTION.get() { None => Err(Error::DatabaseError( "failed to get database pool".to_string(), )), - Some(pool) => Ok(pool.clone()), + Some(pool) => Ok(pool), }; } diff --git a/src/infra/sign_backend/factory.rs b/src/infra/sign_backend/factory.rs index a422d2a..9804187 100644 --- a/src/infra/sign_backend/factory.rs +++ b/src/infra/sign_backend/factory.rs @@ -25,7 +25,7 @@ use crate::infra::sign_backend::memory::backend::MemorySignBackend; pub struct SignBackendFactory {} impl SignBackendFactory { - pub async fn new_engine(config: Arc>, db_connection: DatabaseConnection) -> Result> { + pub async fn new_engine(config: Arc>, db_connection: &'static DatabaseConnection) -> Result> { let engine_type = SignBackendType::from_str( config.read()?.get_string("sign-backend.type")?.as_str(), )?; diff --git a/src/infra/sign_backend/memory/backend.rs b/src/infra/sign_backend/memory/backend.rs index 5322465..176a8c5 100644 --- a/src/infra/sign_backend/memory/backend.rs +++ b/src/infra/sign_backend/memory/backend.rs @@ -50,7 +50,7 @@ impl MemorySignBackend { /// 2. initialize the cluster repo /// 2. initialize the encryption engine including the cluster key /// 3. initialize the signing plugins - pub async fn new(server_config: Arc>, db_connection: DatabaseConnection) -> Result { + pub async fn new(server_config: Arc>, db_connection: &'static DatabaseConnection) -> Result { //initialize the kms backend let kms_provider = factory::KMSProviderFactory::new_provider( &server_config.read()?.get_table("memory.kms-provider")? -- Gitee