From 3efdd068304be4674b969765cb072cc04df2a4ae Mon Sep 17 00:00:00 2001
From: TommyLike
Date: Sun, 8 Oct 2023 15:46:01 +0800
Subject: [PATCH] Fix cargo fmt issue

---
 build.rs | 6 +-
 src/application/datakey.rs | 462 ++++--
 src/application/mod.rs | 2 +-
 src/application/user.rs | 138 +-
 src/client/cmd/add.rs | 228 +--
 src/client/cmd/mod.rs | 2 +-
 src/client/cmd/traits.rs | 11 +-
 src/client/file_handler/efi.rs | 26 +-
 src/client/file_handler/factory.rs | 25 +-
 src/client/file_handler/generic.rs | 10 +-
 src/client/file_handler/kernel_module.rs | 95 +-
 src/client/file_handler/mod.rs | 6 +-
 src/client/file_handler/rpm.rs | 206 +--
 src/client/file_handler/traits.rs | 4 +-
 src/client/load_balancer/dns.rs | 32 +-
 src/client/load_balancer/factory.rs | 95 +-
 src/client/load_balancer/mod.rs | 4 +-
 src/client/load_balancer/single.rs | 21 +-
 src/client/load_balancer/traits.rs | 4 +-
 src/client/mod.rs | 4 +-
 src/client/sign_identity.rs | 13 +-
 src/client/worker/assembler.rs | 45 +-
 src/client/worker/key_fetcher.rs | 21 +-
 src/client/worker/mod.rs | 4 +-
 src/client/worker/signer.rs | 38 +-
 src/client/worker/splitter.rs | 26 +-
 src/client/worker/traits.rs | 8 +-
 src/client_entrypoint.rs | 46 +-
 src/control_admin_entrypoint.rs | 113 +-
 src/control_server_entrypoint.rs | 14 +-
 src/data_server_entrypoint.rs | 9 +-
 src/domain/clusterkey/entity.rs | 29 +-
 src/domain/datakey/entity.rs | 68 +-
 src/domain/datakey/mod.rs | 2 +-
 src/domain/datakey/plugins/mod.rs | 1 -
 src/domain/datakey/plugins/openpgp.rs | 40 +-
 src/domain/datakey/plugins/x509.rs | 24 +-
 src/domain/datakey/repository.rs | 41 +-
 src/domain/encryption_engine.rs | 4 +-
 src/domain/encryptor.rs | 4 +-
 src/domain/kms_provider.rs | 6 +-
 src/domain/mod.rs | 4 +-
 src/domain/sign_plugin.rs | 29 +-
 src/domain/sign_service.rs | 33 +-
 src/domain/token/entity.rs | 12 +-
 src/domain/token/mod.rs | 2 +-
 src/domain/token/repository.rs | 2 +-
 src/domain/user/entity.rs | 17 +-
 src/domain/user/mod.rs | 2 +-
 src/infra/database/model/clusterkey/dto.rs | 9 +-
 .../database/model/clusterkey/repository.rs | 133 +-
 src/infra/database/model/datakey/dto.rs | 35 +-
 .../database/model/datakey/repository.rs | 1315 +++++++++++------
 src/infra/database/model/mod.rs | 6 +-
 .../database/model/request_delete/dto.rs | 9 +-
 src/infra/database/model/token/dto.rs | 5 +-
 src/infra/database/model/token/repository.rs | 189 +--
 src/infra/database/model/user/dto.rs | 10 +-
 src/infra/database/model/user/repository.rs | 82 +-
 .../database/model/x509_crl_content/dto.rs | 17 +-
 .../database/model/x509_revoked_key/dto.rs | 7 +-
 src/infra/database/pool.rs | 27 +-
 src/infra/encryption/algorithm/aes.rs | 76 +-
 src/infra/encryption/algorithm/factory.rs | 8 +-
 src/infra/encryption/dummy_engine.rs | 3 +-
 src/infra/encryption/engine.rs | 86 +-
 src/infra/encryption/mod.rs | 2 +-
 src/infra/kms/dummy.rs | 9 +-
 src/infra/kms/factory.rs | 11 +-
 src/infra/kms/huaweicloud.rs | 139 +-
 src/infra/kms/mod.rs | 2 +-
 src/infra/mod.rs | 4 +-
 src/infra/sign_backend/factory.rs | 24 +-
 src/infra/sign_backend/memory/backend.rs | 69 +-
 src/infra/sign_backend/memory/mod.rs | 2 +-
 src/infra/sign_plugin/mod.rs | 2 +-
 src/infra/sign_plugin/openpgp.rs | 345 +++--
 src/infra/sign_plugin/signers.rs | 17 +-
 src/infra/sign_plugin/util.rs | 27 +-
 src/infra/sign_plugin/x509.rs | 740 +++++++--
 .../handler/control/datakey_handler.rs | 236 ++-
 .../handler/control/health_handler.rs | 7 +-
 src/presentation/handler/control/mod.rs | 4 +-
 .../handler/control/model/datakey/dto.rs | 124 +-
 .../handler/control/model/datakey/mod.rs | 2 +-
 src/presentation/handler/control/model/mod.rs | 2 +-
 .../handler/control/model/token/dto.rs | 11 +-
 .../handler/control/model/token/mod.rs | 2 +-
 .../handler/control/model/user/dto.rs | 138 +-
 .../handler/control/model/user/mod.rs | 2 +-
 .../handler/control/user_handler.rs | 94 +-
 .../handler/data/health_handler.rs | 12 +-
 src/presentation/handler/data/mod.rs | 2 +-
 src/presentation/handler/data/sign_handler.rs | 109 +-
 src/presentation/mod.rs | 2 +-
 src/presentation/server/control_server.rs | 150 +-
 src/presentation/server/data_server.rs | 73 +-
 src/util/cache.rs | 142 +-
 src/util/config.rs | 13 +-
 src/util/error.rs | 113 +-
 src/util/key.rs | 84 +-
 src/util/mod.rs | 4 +-
 src/util/options.rs | 2 +-
 src/util/sign.rs | 2 +-
 104 files changed, 4326 insertions(+), 2452 deletions(-)

diff --git a/build.rs b/build.rs
index f2af3e3..27b5315 100644
--- a/build.rs
+++ b/build.rs
@@ -1,6 +1,8 @@
 fn main() {
     let sign_proto = "./proto/signatrust.proto";
     let health_proto = "./proto/health.proto";
-    tonic_build::configure().protoc_arg("--experimental_allow_proto3_optional"
-    ).compile(&[sign_proto, health_proto], &["proto"]).unwrap();
+    tonic_build::configure()
+        .protoc_arg("--experimental_allow_proto3_optional")
+        .compile(&[sign_proto, health_proto], &["proto"])
+        .unwrap();
 }
diff --git a/src/application/datakey.rs b/src/application/datakey.rs
index 9f31041..3edd2fc 100644
--- a/src/application/datakey.rs
+++ b/src/application/datakey.rs
@@ -14,24 +14,27 @@
  *
  */
 
+use crate::domain::datakey::entity::{
+    DataKey, DatakeyPaginationQuery, KeyAction, KeyState, KeyType, PagedDatakey, Visibility,
+    X509RevokeReason, X509CRL,
+};
 use crate::domain::datakey::repository::Repository as DatakeyRepository;
 use crate::domain::sign_service::SignBackend;
 use crate::util::error::{Error, Result};
 use async_trait::async_trait;
-use crate::domain::datakey::entity::{DataKey, DatakeyPaginationQuery, KeyAction, KeyState, KeyType, PagedDatakey, Visibility, X509CRL, X509RevokeReason};
 use tokio::time::{self};
 
+use crate::domain::datakey::entity::KeyType::{OpenPGP, X509CA, X509EE, X509ICA};
+use crate::presentation::handler::control::model::user::dto::UserIdentity;
 use crate::util::cache::TimedFixedSizeCache;
-use std::collections::HashMap;
-use std::sync::{Arc};
 use chrono::{Duration, Utc};
+use std::collections::HashMap;
+use std::sync::Arc;
 use tokio::sync::RwLock;
 use tokio_util::sync::CancellationToken;
-use crate::domain::datakey::entity::KeyType::{OpenPGP, X509CA, X509EE, X509ICA};
-use crate::presentation::handler::control::model::user::dto::UserIdentity;
 
 #[async_trait]
-pub trait KeyService: Send + Sync{
+pub trait KeyService: Send + Sync {
     async fn create(&self, user: UserIdentity, data: &mut DataKey) -> Result<DataKey>;
     async fn import(&self, data: &mut DataKey) -> Result<DataKey>;
     async fn get_raw_key_by_name(&self, name: &str) -> Result<DataKey>;
@@ -40,62 +43,91 @@ pub trait KeyService: Send + Sync{
     async fn get_one(&self, user: Option<UserIdentity>, id_or_name: String) -> Result<DataKey>;
     //get keys content
     async fn export_one(&self, user: Option<UserIdentity>, id_or_name: String) -> Result<DataKey>;
-    async fn export_cert_crl(&self, user: Option<UserIdentity>,id_or_name: String) -> Result<X509CRL>;
+    async fn export_cert_crl(
+        &self,
+        user: Option<UserIdentity>,
+        id_or_name: String,
+    ) -> Result<X509CRL>;
     //keys related operation
     async fn request_delete(&self, user: UserIdentity, id_or_name: String) -> Result<()>;
     async fn cancel_delete(&self, user: UserIdentity, id_or_name: String) -> Result<()>;
-    async fn request_revoke(&self, user: UserIdentity, id_or_name: String, reason: X509RevokeReason) -> Result<()>;
+    async fn
request_revoke( + &self, + user: UserIdentity, + id_or_name: String, + reason: X509RevokeReason, + ) -> Result<()>; async fn cancel_revoke(&self, user: UserIdentity, id_or_name: String) -> Result<()>; async fn enable(&self, user: Option, id_or_name: String) -> Result<()>; async fn disable(&self, user: Option, id_or_name: String) -> Result<()>; //used for data server - async fn sign(&self, key_type: String, key_name: String, options: &HashMap, data: Vec) ->Result>; - async fn get_by_type_and_name(&self, key_type: String, key_name: String) ->Result; + async fn sign( + &self, + key_type: String, + key_name: String, + options: &HashMap, + data: Vec, + ) -> Result>; + async fn get_by_type_and_name(&self, key_type: String, key_name: String) -> Result; //method below used for maintenance fn start_key_rotate_loop(&self, cancel_token: CancellationToken) -> Result<()>; //method below used for x509 crl - fn start_key_plugin_maintenance(&self, cancel_token: CancellationToken, refresh_days: i32) -> Result<()>; + fn start_key_plugin_maintenance( + &self, + cancel_token: CancellationToken, + refresh_days: i32, + ) -> Result<()>; } - - pub struct DBKeyService where R: DatakeyRepository + Clone + 'static, - S: SignBackend + ?Sized + 'static + S: SignBackend + ?Sized + 'static, { repository: R, sign_service: Arc>>, - container: TimedFixedSizeCache + container: TimedFixedSizeCache, } impl DBKeyService - where - R: DatakeyRepository + Clone + 'static, - S: SignBackend + ?Sized + 'static +where + R: DatakeyRepository + Clone + 'static, + S: SignBackend + ?Sized + 'static, { pub fn new(repository: R, sign_service: Box) -> Self { Self { repository, sign_service: Arc::new(RwLock::new(sign_service)), - container: TimedFixedSizeCache::new(Some(100), None, None, None) + container: TimedFixedSizeCache::new(Some(100), None, None, None), } } - async fn get_and_check_permission(&self, user: Option, id_or_name: String, action: KeyAction, raw_key: bool) -> Result { + async fn get_and_check_permission( + &self, + user: Option, + id_or_name: String, + action: KeyAction, + raw_key: bool, + ) -> Result { let id = id_or_name.parse::(); let data_key: DataKey = match id { Ok(id) => { - self.repository.get_by_id_or_name(Some(id), None, raw_key).await? + self.repository + .get_by_id_or_name(Some(id), None, raw_key) + .await? } Err(_) => { - self.repository.get_by_id_or_name(None, Some(id_or_name), raw_key).await? + self.repository + .get_by_id_or_name(None, Some(id_or_name), raw_key) + .await? 
} }; //check permission for private keys - if data_key.visibility == Visibility::Private && (user.is_none() || data_key.user != user.unwrap().id) { + if data_key.visibility == Visibility::Private + && (user.is_none() || data_key.user != user.unwrap().id) + { return Err(Error::UnprivilegedError); } self.validate_type_and_state(&data_key, action)?; @@ -104,71 +136,185 @@ impl DBKeyService fn validate_type_and_state(&self, key: &DataKey, key_action: KeyAction) -> Result<()> { let valid_action_by_key_type = HashMap::from([ - (OpenPGP, vec![KeyAction::Delete, KeyAction::CancelDelete, KeyAction::Disable, KeyAction::Enable, KeyAction::Sign, KeyAction::Read]), - (X509CA, vec![KeyAction::Delete, KeyAction::CancelDelete, KeyAction::Disable, KeyAction::Enable, KeyAction::IssueCert, KeyAction::Read]), - (X509ICA, vec![KeyAction::Delete, KeyAction::CancelDelete, KeyAction::Revoke, KeyAction::CancelRevoke, KeyAction::Disable, KeyAction::Enable, KeyAction::Read, KeyAction::IssueCert]), - (X509EE, vec![KeyAction::Delete, KeyAction::CancelDelete, KeyAction::Revoke, KeyAction::CancelRevoke, KeyAction::Disable, KeyAction::Enable, KeyAction::Read, KeyAction::Sign]), + ( + OpenPGP, + vec![ + KeyAction::Delete, + KeyAction::CancelDelete, + KeyAction::Disable, + KeyAction::Enable, + KeyAction::Sign, + KeyAction::Read, + ], + ), + ( + X509CA, + vec![ + KeyAction::Delete, + KeyAction::CancelDelete, + KeyAction::Disable, + KeyAction::Enable, + KeyAction::IssueCert, + KeyAction::Read, + ], + ), + ( + X509ICA, + vec![ + KeyAction::Delete, + KeyAction::CancelDelete, + KeyAction::Revoke, + KeyAction::CancelRevoke, + KeyAction::Disable, + KeyAction::Enable, + KeyAction::Read, + KeyAction::IssueCert, + ], + ), + ( + X509EE, + vec![ + KeyAction::Delete, + KeyAction::CancelDelete, + KeyAction::Revoke, + KeyAction::CancelRevoke, + KeyAction::Disable, + KeyAction::Enable, + KeyAction::Read, + KeyAction::Sign, + ], + ), ]); let valid_state_by_key_action = HashMap::from([ - (KeyAction::Delete, vec![KeyState::Disabled, KeyState::Revoked, KeyState::PendingDelete]), + ( + KeyAction::Delete, + vec![ + KeyState::Disabled, + KeyState::Revoked, + KeyState::PendingDelete, + ], + ), (KeyAction::CancelDelete, vec![KeyState::PendingDelete]), - (KeyAction::Revoke, vec![KeyState::Disabled, KeyState::PendingRevoke]), + ( + KeyAction::Revoke, + vec![KeyState::Disabled, KeyState::PendingRevoke], + ), (KeyAction::CancelRevoke, vec![KeyState::PendingRevoke]), (KeyAction::Enable, vec![KeyState::Disabled]), (KeyAction::Disable, vec![KeyState::Enabled]), - (KeyAction::Sign, vec![KeyState::Enabled, KeyState::PendingDelete, KeyState::PendingRevoke]), - (KeyAction::IssueCert, vec![KeyState::Enabled, KeyState::PendingDelete, KeyState::PendingRevoke]), - (KeyAction::Read, vec![KeyState::Enabled, KeyState::PendingDelete, KeyState::PendingRevoke, KeyState::Disabled]), + ( + KeyAction::Sign, + vec![ + KeyState::Enabled, + KeyState::PendingDelete, + KeyState::PendingRevoke, + ], + ), + ( + KeyAction::IssueCert, + vec![ + KeyState::Enabled, + KeyState::PendingDelete, + KeyState::PendingRevoke, + ], + ), + ( + KeyAction::Read, + vec![ + KeyState::Enabled, + KeyState::PendingDelete, + KeyState::PendingRevoke, + KeyState::Disabled, + ], + ), ]); match valid_action_by_key_type.get(&key.key_type) { None => { - return Err(Error::ConfigError("key type is missing, please check the key type".to_string())); + return Err(Error::ConfigError( + "key type is missing, please check the key type".to_string(), + )); } Some(actions) => { if !actions.contains(&key_action) { - 
return Err(Error::ActionsNotAllowedError(format!("action '{}' is not permitted for key type '{}'", key_action, key.key_type))); + return Err(Error::ActionsNotAllowedError(format!( + "action '{}' is not permitted for key type '{}'", + key_action, key.key_type + ))); } } } match valid_state_by_key_action.get(&key_action) { None => { - return Err(Error::ConfigError("key action is missing, please check the key action".to_string())) + return Err(Error::ConfigError( + "key action is missing, please check the key action".to_string(), + )) } Some(states) => { if !states.contains(&key.key_state) { - return Err(Error::ActionsNotAllowedError(format!("action '{}' is not permitted for state '{}'", key_action, key.key_state))) + return Err(Error::ActionsNotAllowedError(format!( + "action '{}' is not permitted for state '{}'", + key_action, key.key_state + ))); } } } - if (key_action == KeyAction::Revoke || key_action == KeyAction::CancelRevoke) && key.parent_id.is_none() { - return Err(Error::ActionsNotAllowedError(format!("action '{}' is not permitted for key without parent", key_action))) + if (key_action == KeyAction::Revoke || key_action == KeyAction::CancelRevoke) + && key.parent_id.is_none() + { + return Err(Error::ActionsNotAllowedError(format!( + "action '{}' is not permitted for key without parent", + key_action + ))); } Ok(()) } - async fn check_key_hierarchy(&self, user: UserIdentity, data: &DataKey, parent_id: i32) -> Result<()> { - let parent_key = self.repository.get_by_id_or_name(Some(parent_id), None, true).await?; + async fn check_key_hierarchy( + &self, + user: UserIdentity, + data: &DataKey, + parent_id: i32, + ) -> Result<()> { + let parent_key = self + .repository + .get_by_id_or_name(Some(parent_id), None, true) + .await?; //check permission for private keys if parent_key.visibility == Visibility::Private && parent_key.user != user.id { return Err(Error::UnprivilegedError); } if parent_key.visibility != data.visibility { - return Err(Error::ActionsNotAllowedError(format!("parent key '{}' visibility not equal to current datakey", parent_key.name))); + return Err(Error::ActionsNotAllowedError(format!( + "parent key '{}' visibility not equal to current datakey", + parent_key.name + ))); } if parent_key.key_state != KeyState::Enabled { - return Err(Error::ActionsNotAllowedError(format!("parent key '{}' not in enable state", parent_key.name))); + return Err(Error::ActionsNotAllowedError(format!( + "parent key '{}' not in enable state", + parent_key.name + ))); } if parent_key.expire_at < data.expire_at { - return Err(Error::ActionsNotAllowedError(format!("parent key '{}' expire time is less than child key", parent_key.name))); + return Err(Error::ActionsNotAllowedError(format!( + "parent key '{}' expire time is less than child key", + parent_key.name + ))); } if data.key_type == X509ICA && parent_key.key_type != X509CA { - return Err(Error::ActionsNotAllowedError("only CA key is allowed for creating ICA".to_string())); + return Err(Error::ActionsNotAllowedError( + "only CA key is allowed for creating ICA".to_string(), + )); } if data.key_type == X509EE && parent_key.key_type != X509ICA { - return Err(Error::ActionsNotAllowedError("only ICA key is allowed for creating End Entity Key".to_string())); + return Err(Error::ActionsNotAllowedError( + "only ICA key is allowed for creating End Entity Key".to_string(), + )); } if data.key_type == X509CA || data.key_type == OpenPGP { - return Err(Error::ActionsNotAllowedError("CA key or openPGP is not allowed to specify parent key".to_string())); + 
return Err(Error::ActionsNotAllowedError( + "CA key or openPGP is not allowed to specify parent key".to_string(), + )); } Ok(()) } @@ -178,7 +324,7 @@ impl DBKeyService impl KeyService for DBKeyService where R: DatakeyRepository + Clone + 'static, - S: SignBackend + ?Sized + 'static + S: SignBackend + ?Sized + 'static, { async fn create(&self, user: UserIdentity, data: &mut DataKey) -> Result { //check parent key is enabled,expire time is greater than child key and hierarchy is correct @@ -186,8 +332,16 @@ where self.check_key_hierarchy(user, data, parent_id).await?; } //check datakey existence - if self.repository.get_by_id_or_name(None, Some(data.name.clone()), true).await.is_ok() { - return Err(Error::ParameterError(format!("datakey '{}' already exists", data.name))); + if self + .repository + .get_by_id_or_name(None, Some(data.name.clone()), true) + .await + .is_ok() + { + return Err(Error::ParameterError(format!( + "datakey '{}' already exists", + data.name + ))); } //we need to create a key in database first, then generate sensitive data let mut key = self.repository.create(data.clone()).await?; @@ -204,37 +358,58 @@ where } async fn import(&self, data: &mut DataKey) -> Result { - self.sign_service.read().await.validate_and_update(data).await?; + self.sign_service + .read() + .await + .validate_and_update(data) + .await?; self.repository.create(data.clone()).await } async fn get_raw_key_by_name(&self, name: &str) -> Result { - self.repository.get_by_id_or_name(None, Some(name.to_owned()), true).await + self.repository + .get_by_id_or_name(None, Some(name.to_owned()), true) + .await } - async fn get_all(&self, user_id: i32, query: DatakeyPaginationQuery) -> Result { - self.repository.get_all_keys( user_id, query).await + async fn get_all(&self, user_id: i32, query: DatakeyPaginationQuery) -> Result { + self.repository.get_all_keys(user_id, query).await } - async fn get_one(&self, user: Option, id_or_name: String) -> Result { - let datakey = self.get_and_check_permission(user, id_or_name, KeyAction::Read, false).await?; + async fn get_one(&self, user: Option, id_or_name: String) -> Result { + let datakey = self + .get_and_check_permission(user, id_or_name, KeyAction::Read, false) + .await?; Ok(datakey) - } async fn export_one(&self, user: Option, id_or_name: String) -> Result { //NOTE: since the public key or certificate basically will not change at all, we will cache the key here. 
if let Some(datakey) = self.container.get_read_datakey(&id_or_name).await { - return Ok(datakey) + return Ok(datakey); } - let mut key = self.get_and_check_permission(user, id_or_name.clone(), KeyAction::Read, true).await?; - self.sign_service.read().await.decode_public_keys(&mut key).await?; - self.container.update_read_datakey(&id_or_name, key.clone()).await?; + let mut key = self + .get_and_check_permission(user, id_or_name.clone(), KeyAction::Read, true) + .await?; + self.sign_service + .read() + .await + .decode_public_keys(&mut key) + .await?; + self.container + .update_read_datakey(&id_or_name, key.clone()) + .await?; Ok(key) } - async fn export_cert_crl(&self, user: Option, id_or_name: String) -> Result { - let key = self.get_and_check_permission(user, id_or_name, KeyAction::Read, true).await?; + async fn export_cert_crl( + &self, + user: Option, + id_or_name: String, + ) -> Result { + let key = self + .get_and_check_permission(user, id_or_name, KeyAction::Read, true) + .await?; let crl = self.repository.get_x509_crl_by_ca_id(key.id).await?; Ok(crl) } @@ -242,59 +417,116 @@ where async fn request_delete(&self, user: UserIdentity, id_or_name: String) -> Result<()> { let user_id = user.id; let user_email = user.email.clone(); - let key = self.get_and_check_permission(Some(user), id_or_name, KeyAction::Delete, true).await?; + let key = self + .get_and_check_permission(Some(user), id_or_name, KeyAction::Delete, true) + .await?; //check if the ca/ica key is used by other keys if key.key_type == KeyType::X509ICA || key.key_type == KeyType::X509CA { let children = self.repository.get_by_parent_id(key.id).await?; if !children.is_empty() { - return Err(Error::ActionsNotAllowedError(format!("key '{}' is used by other keys, request delete is not allowed", key.name))); + return Err(Error::ActionsNotAllowedError(format!( + "key '{}' is used by other keys, request delete is not allowed", + key.name + ))); } } - self.repository.request_delete_key(user_id, user_email, key.id, key.visibility == Visibility::Public).await + self.repository + .request_delete_key( + user_id, + user_email, + key.id, + key.visibility == Visibility::Public, + ) + .await } async fn cancel_delete(&self, user: UserIdentity, id_or_name: String) -> Result<()> { let user_id = user.id; - let key = self.get_and_check_permission(Some(user), id_or_name, KeyAction::CancelDelete, true).await?; + let key = self + .get_and_check_permission(Some(user), id_or_name, KeyAction::CancelDelete, true) + .await?; self.repository.cancel_delete_key(user_id, key.id).await } - async fn request_revoke(&self, user: UserIdentity, id_or_name: String, reason: X509RevokeReason) -> Result<()> { + async fn request_revoke( + &self, + user: UserIdentity, + id_or_name: String, + reason: X509RevokeReason, + ) -> Result<()> { let user_id = user.id; let user_email = user.email.clone(); - let key = self.get_and_check_permission(Some(user), id_or_name, KeyAction::Revoke, true).await?; - self.repository.request_revoke_key(user_id, user_email, key.id, key.parent_id.unwrap(), reason, key.visibility == Visibility::Public).await?; + let key = self + .get_and_check_permission(Some(user), id_or_name, KeyAction::Revoke, true) + .await?; + self.repository + .request_revoke_key( + user_id, + user_email, + key.id, + key.parent_id.unwrap(), + reason, + key.visibility == Visibility::Public, + ) + .await?; Ok(()) } async fn cancel_revoke(&self, user: UserIdentity, id_or_name: String) -> Result<()> { let user_id = user.id; - let key = self.get_and_check_permission(Some(user), 
id_or_name, KeyAction::CancelRevoke, true).await?; - self.repository.cancel_revoke_key(user_id, key.id, key.parent_id.unwrap()).await?; + let key = self + .get_and_check_permission(Some(user), id_or_name, KeyAction::CancelRevoke, true) + .await?; + self.repository + .cancel_revoke_key(user_id, key.id, key.parent_id.unwrap()) + .await?; Ok(()) } async fn enable(&self, user: Option, id_or_name: String) -> Result<()> { - let key = self.get_and_check_permission(user, id_or_name, KeyAction::Enable, true).await?; - self.repository.update_state(key.id, KeyState::Enabled).await + let key = self + .get_and_check_permission(user, id_or_name, KeyAction::Enable, true) + .await?; + self.repository + .update_state(key.id, KeyState::Enabled) + .await } async fn disable(&self, user: Option, id_or_name: String) -> Result<()> { - let key = self.get_and_check_permission(user, id_or_name, KeyAction::Disable, true).await?; - self.repository.update_state(key.id, KeyState::Disabled).await + let key = self + .get_and_check_permission(user, id_or_name, KeyAction::Disable, true) + .await?; + self.repository + .update_state(key.id, KeyState::Disabled) + .await } - async fn sign(&self, key_type: String, key_name: String, options: &HashMap, data: Vec) -> Result> { + async fn sign( + &self, + key_type: String, + key_name: String, + options: &HashMap, + data: Vec, + ) -> Result> { let datakey = self.get_by_type_and_name(key_type, key_name).await?; - self.sign_service.read().await.sign(&datakey, data, options.clone()).await + self.sign_service + .read() + .await + .sign(&datakey, data, options.clone()) + .await } async fn get_by_type_and_name(&self, key_type: String, key_name: String) -> Result { if let Some(datakey) = self.container.get_sign_datakey(&key_name).await { - return Ok(datakey) + return Ok(datakey); } - let key = self.repository.get_enabled_key_by_type_and_name_with_parent_key(key_type, key_name.clone()).await?; - self.container.update_sign_datakey(&key_name, key.clone()).await?; + let key = self + .repository + .get_enabled_key_by_type_and_name_with_parent_key(key_type, key_name.clone()) + .await?; + self.container + .update_sign_datakey(&key_name, key.clone()) + .await?; Ok(key) } fn start_key_rotate_loop(&self, cancel_token: CancellationToken) -> Result<()> { @@ -322,47 +554,53 @@ where } } } - }); Ok(()) } - fn start_key_plugin_maintenance(&self, cancel_token: CancellationToken, refresh_days: i32) -> Result<()> { + fn start_key_plugin_maintenance( + &self, + cancel_token: CancellationToken, + refresh_days: i32, + ) -> Result<()> { let mut interval = time::interval(Duration::hours(2).to_std()?); let duration = Duration::days(refresh_days as i64); let repository = self.repository.clone(); let sign_service = self.sign_service.clone(); tokio::spawn(async move { - loop { tokio::select! 
{ - _ = interval.tick() => { - info!("start to update execute key plugin maintenance"); - match repository.get_keys_for_crl_update(duration).await { - Ok(keys) => { - let now = Utc::now(); - for key in keys { - match repository.get_revoked_serial_number_by_parent_id(key.id).await { - Ok(revoke_keys) => { - match sign_service.read().await.generate_crl_content(&key, revoke_keys, now, now + duration).await { - Ok(data) => { - let crl_content = X509CRL::new(key.id, data, now, now); - if let Err(e) = repository.upsert_x509_crl(crl_content).await { - error!("Failed to update CRL content for key: {} {}, {}", key.key_state, key.id, e); - } else { - info!("CRL has been successfully updated for key: {} {}", key.key_type, key.id); - }} - Err(e) => { - error!("failed to update CRL content for key: {} {} and error {}", key.key_state, key.id, e); - }}} - Err(e) => { - error!("failed to get revoked keys for key {} {}, error {}", key.key_state, key.id, e); - }}}} - Err(e) => { - error!("failed to get keys for CRL update: {}", e); - }}} - _ = cancel_token.cancelled() => { - info!("cancel token received, will quit key plugin maintenance loop"); - break; - }}}}); + loop { + tokio::select! { + _ = interval.tick() => { + info!("start to update execute key plugin maintenance"); + match repository.get_keys_for_crl_update(duration).await { + Ok(keys) => { + let now = Utc::now(); + for key in keys { + match repository.get_revoked_serial_number_by_parent_id(key.id).await { + Ok(revoke_keys) => { + match sign_service.read().await.generate_crl_content(&key, revoke_keys, now, now + duration).await { + Ok(data) => { + let crl_content = X509CRL::new(key.id, data, now, now); + if let Err(e) = repository.upsert_x509_crl(crl_content).await { + error!("Failed to update CRL content for key: {} {}, {}", key.key_state, key.id, e); + } else { + info!("CRL has been successfully updated for key: {} {}", key.key_type, key.id); + }} + Err(e) => { + error!("failed to update CRL content for key: {} {} and error {}", key.key_state, key.id, e); + }}} + Err(e) => { + error!("failed to get revoked keys for key {} {}, error {}", key.key_state, key.id, e); + }}}} + Err(e) => { + error!("failed to get keys for CRL update: {}", e); + }}} + _ = cancel_token.cancelled() => { + info!("cancel token received, will quit key plugin maintenance loop"); + break; + }} + } + }); Ok(()) } } diff --git a/src/application/mod.rs b/src/application/mod.rs index 33d891a..456c34c 100644 --- a/src/application/mod.rs +++ b/src/application/mod.rs @@ -1,2 +1,2 @@ +pub mod datakey; pub mod user; -pub mod datakey; \ No newline at end of file diff --git a/src/application/user.rs b/src/application/user.rs index 7e4e549..2244a2a 100644 --- a/src/application/user.rs +++ b/src/application/user.rs @@ -14,32 +14,32 @@ * */ -use crate::domain::user::entity::User; use crate::domain::token::entity::Token; -use crate::domain::user::repository::Repository as UserRepository; use crate::domain::token::repository::Repository as TokenRepository; -use crate::util::error::{Result, Error}; +use crate::domain::user::entity::User; +use crate::domain::user::repository::Repository as UserRepository; +use crate::presentation::handler::control::model::token::dto::CreateTokenDTO; +use crate::presentation::handler::control::model::user::dto::UserIdentity; +use crate::util::cache::TimedFixedSizeCache; +use crate::util::error::{Error, Result}; +use crate::util::key::generate_api_token; use async_trait::async_trait; -use std::sync::Arc; -use std::sync::RwLock; use chrono::Utc; -use 
serde::{Deserialize}; use config::Config; -use reqwest::{header, Client, StatusCode}; -use crate::presentation::handler::control::model::user::dto::UserIdentity; use openidconnect::{ - Scope, - AuthenticationFlow, CsrfToken, Nonce, - core::CoreResponseType, core::CoreClient + core::CoreClient, core::CoreResponseType, AuthenticationFlow, CsrfToken, Nonce, Scope, }; -use openidconnect::{JsonWebKeySet, ClientId, AuthUrl, UserInfoUrl, TokenUrl, RedirectUrl, ClientSecret, IssuerUrl}; +use openidconnect::{ + AuthUrl, ClientId, ClientSecret, IssuerUrl, JsonWebKeySet, RedirectUrl, TokenUrl, UserInfoUrl, +}; +use reqwest::{header, Client, StatusCode}; +use serde::Deserialize; +use std::sync::Arc; +use std::sync::RwLock; use url::Url; -use crate::presentation::handler::control::model::token::dto::{CreateTokenDTO}; -use crate::util::key::{generate_api_token}; -use crate::util::cache::TimedFixedSizeCache; #[async_trait] -pub trait UserService: Send + Sync{ +pub trait UserService: Send + Sync { async fn get_token(&self, u: &UserIdentity) -> Result>; async fn delete_token(&self, u: &UserIdentity, id: i32) -> Result<()>; async fn get_valid_token(&self, token: &str) -> Result; @@ -69,14 +69,13 @@ pub struct OIDCConfig { pub token_url: String, pub redirect_uri: String, pub user_info_url: String, - pub auth_url: String + pub auth_url: String, } - pub struct DBUserService where R: UserRepository, - T: TokenRepository + T: TokenRepository, { user_repository: R, token_repository: T, @@ -86,14 +85,18 @@ where } impl DBUserService - where - R: UserRepository, - T: TokenRepository +where + R: UserRepository, + T: TokenRepository, { - pub fn new(user_repository: R, token_repository: T, config: Arc>) -> Result { + pub fn new( + user_repository: R, + token_repository: T, + config: Arc>, + ) -> Result { // TODO: remove me when openid connect library is ready we have to save OIDC in another object // due to we hacked several OIDC methods. - let oidc_config = OIDCConfig{ + let oidc_config = OIDCConfig { auth_url: config.read()?.get_string("oidc.auth_url")?, client_id: config.read()?.get_string("oidc.client_id")?, client_secret: config.read()?.get_string("oidc.client_secret")?, @@ -108,14 +111,15 @@ impl DBUserService AuthUrl::new(oidc_config.auth_url.clone())?, Some(TokenUrl::new(oidc_config.token_url.clone())?), Some(UserInfoUrl::new(oidc_config.user_info_url.clone())?), - JsonWebKeySet::default()).set_redirect_uri(RedirectUrl::new(oidc_config.redirect_uri.clone())?, - ); + JsonWebKeySet::default(), + ) + .set_redirect_uri(RedirectUrl::new(oidc_config.redirect_uri.clone())?); Ok(Self { user_repository, token_repository, oidc_config, client, - tokens: TimedFixedSizeCache::new(None, Some(20), None, None) + tokens: TimedFixedSizeCache::new(None, Some(20), None, None), }) } @@ -123,15 +127,21 @@ impl DBUserService // https://github.com/ramosbugs/openidconnect-rs/issues/100 async fn get_user_info(&self, access_token: &str) -> Result { let mut auth_header = header::HeaderMap::new(); - auth_header.insert("Authorization", header::HeaderValue::from_str( access_token)?); + auth_header.insert( + "Authorization", + header::HeaderValue::from_str(access_token)?, + ); match Client::builder().default_headers(auth_header).build() { Ok(client) => { - let resp: UserEmail = client.get(&self.oidc_config.user_info_url).send().await?.json().await?; + let resp: UserEmail = client + .get(&self.oidc_config.user_info_url) + .send() + .await? 
+ .json() + .await?; Ok(resp) } - Err(err) => { - Err(Error::AuthError(err.to_string())) - } + Err(err) => Err(Error::AuthError(err.to_string())), } } @@ -139,22 +149,28 @@ impl DBUserService async fn get_access_token(&self, code: &str) -> Result { match Client::builder().build() { Ok(client) => { - let response= client.post(&self.oidc_config.token_url).query(&[ - ("client_id", self.oidc_config.client_id.as_str()), - ("client_secret", self.oidc_config.client_secret.as_str()), - ("code", code), - ("redirect_uri", self.oidc_config.redirect_uri.as_str()), - ("grant_type", "authorization_code")]).send().await?; + let response = client + .post(&self.oidc_config.token_url) + .query(&[ + ("client_id", self.oidc_config.client_id.as_str()), + ("client_secret", self.oidc_config.client_secret.as_str()), + ("code", code), + ("redirect_uri", self.oidc_config.redirect_uri.as_str()), + ("grant_type", "authorization_code"), + ]) + .send() + .await?; if response.status() != StatusCode::OK { - Err(Error::AuthError(format!("failed to get access token {}", response.text().await?))) + Err(Error::AuthError(format!( + "failed to get access token {}", + response.text().await? + ))) } else { let resp: AccessToken = response.json().await?; Ok(resp) } } - Err(err) => { - Err(Error::AuthError(err.to_string())) - } + Err(err) => Err(Error::AuthError(err.to_string())), } } } @@ -163,7 +179,7 @@ impl DBUserService impl UserService for DBUserService where R: UserRepository, - T: TokenRepository + T: TokenRepository, { async fn get_token(&self, user: &UserIdentity) -> Result> { self.token_repository.get_token_by_user_id(user.id).await @@ -172,7 +188,7 @@ where async fn delete_token(&self, u: &UserIdentity, id: i32) -> Result<()> { let token = self.token_repository.get_token_by_id(id).await?; if token.user_id != u.id { - return Err(Error::UnauthorizedError) + return Err(Error::UnauthorizedError); } self.token_repository.delete_by_user_and_id(id, u.id).await } @@ -180,23 +196,23 @@ where async fn get_valid_token(&self, token: &str) -> Result { let token = self.token_repository.get_token_by_value(token).await?; if token.expire_at.gt(&Utc::now()) { - return Ok(token) + return Ok(token); } Err(Error::TokenExpiredError(token.to_string())) } async fn validate_token(&self, token: &str) -> Result { if let Some(u) = self.tokens.get_user(token).await { - return Ok(u) + return Ok(u); } let tk = self.get_valid_token(token).await?; let user = self.user_repository.get_by_id(tk.user_id).await?; self.tokens.update_user(token, user.clone()).await?; - return Ok(user) + return Ok(user); } async fn save(&self, u: User) -> Result { - return self.user_repository.create(u).await + return self.user_repository.create(u).await; } async fn get_user_by_id(&self, id: i32) -> Result { @@ -214,13 +230,16 @@ where //return token with un-hashed value new.token = real_token; Ok(new) - } async fn get_login_url(&self) -> Result { - let (authorize_url, _, _) = self.client - .authorize_url(AuthenticationFlow::::AuthorizationCode, - CsrfToken::new_random, Nonce::new_random, ) + let (authorize_url, _, _) = self + .client + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) .add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("openid".to_string())) .add_scope(Scope::new("profile".to_string())) @@ -231,12 +250,17 @@ where async fn validate_user(&self, code: &str) -> Result { match self.get_access_token(code).await { Ok(token_response) => { - let id: User = 
User::new(self.get_user_info(&token_response.access_token).await?.email)?; - return self.user_repository.create(id).await - } - Err(err) => { - Err(Error::AuthError(format!("failed to get access token {}", err))) + let id: User = User::new( + self.get_user_info(&token_response.access_token) + .await? + .email, + )?; + return self.user_repository.create(id).await; } + Err(err) => Err(Error::AuthError(format!( + "failed to get access token {}", + err + ))), } } diff --git a/src/client/cmd/add.rs b/src/client/cmd/add.rs index 2f03916..a8cfb28 100644 --- a/src/client/cmd/add.rs +++ b/src/client/cmd/add.rs @@ -14,32 +14,32 @@ * */ -use clap::{Args}; +use super::traits::SignCommand; +use crate::client::sign_identity; use crate::util::error::Result; -use config::{Config}; +use clap::Args; +use config::Config; use regex::Regex; -use std::sync::{Arc, atomic::AtomicBool, RwLock}; -use super::traits::SignCommand; +use std::collections::HashMap; use std::path::PathBuf; +use std::sync::{atomic::AtomicBool, Arc, RwLock}; use tokio::runtime; -use crate::client::sign_identity; -use std::collections::HashMap; +use crate::client::file_handler::factory::FileHandlerFactory; use crate::util::error; -use async_channel::{bounded}; -use crate::util::sign::{SignType, FileType, KeyType}; use crate::util::options; -use crate::client::file_handler::factory::FileHandlerFactory; +use crate::util::sign::{FileType, KeyType, SignType}; +use async_channel::bounded; use crate::client::load_balancer::factory::ChannelFactory; use crate::client::worker::assembler::Assembler; +use crate::client::worker::key_fetcher::KeyFetcher; use crate::client::worker::signer::RemoteSigner; use crate::client::worker::splitter::Splitter; use crate::client::worker::traits::SignHandler; -use std::sync::atomic::{AtomicI32, Ordering}; -use crate::client::worker::key_fetcher::KeyFetcher; use crate::util::error::Error::CommandProcessFailed; use crate::util::key::file_exists; +use std::sync::atomic::{AtomicI32, Ordering}; lazy_static! { pub static ref FILE_EXTENSION: HashMap> = HashMap::from([ @@ -67,18 +67,23 @@ pub struct CommandAdd { #[arg(long)] #[arg(help = "create detached signature")] detached: bool, - #[arg(help = "specify the path which will be used for signing file and directory are supported")] + #[arg( + help = "specify the path which will be used for signing file and directory are supported" + )] path: String, #[arg(long)] #[arg(value_enum, default_value_t=SignType::Cms)] - #[arg(help = "specify the signature type, meaningful when key type is x509, EFI file supports `authenticode` only and KO file supports `cms` and `pkcs7`")] + #[arg( + help = "specify the signature type, meaningful when key type is x509, EFI file supports `authenticode` only and KO file supports `cms` and `pkcs7`" + )] sign_type: SignType, #[arg(long)] - #[arg(help = "force create rpm v3 signature, default is false. only support when file type is rpm")] + #[arg( + help = "force create rpm v3 signature, default is false. 
only support when file type is rpm" + )] rpm_v3: bool, } - #[derive(Clone)] pub struct CommandAddHandler { worker_threads: usize, @@ -89,7 +94,7 @@ pub struct CommandAddHandler { path: PathBuf, buffer_size: usize, signal: Arc, - config: Arc>, + config: Arc>, detached: bool, max_concurrency: usize, sign_type: SignType, @@ -98,74 +103,91 @@ pub struct CommandAddHandler { } impl CommandAddHandler { - fn get_sign_options(&self) -> HashMap { HashMap::from([ (options::DETACHED.to_string(), self.detached.to_string()), (options::KEY_TYPE.to_string(), self.key_type.to_string()), (options::SIGN_TYPE.to_string(), self.sign_type.to_string()), - (options::RPM_V3_SIGNATURE.to_string(), self.rpm_v3.to_string())]) + ( + options::RPM_V3_SIGNATURE.to_string(), + self.rpm_v3.to_string(), + ), + ]) } fn collect_file_candidates(&self) -> Result> { if self.path.is_dir() { let mut container = Vec::new(); for entry in walkdir::WalkDir::new(self.path.to_str().unwrap()) { match entry { - Ok(en)=> { + Ok(en) => { if en.metadata()?.is_dir() { - continue + continue; } if let Some(extension) = en.path().extension() { if self.file_candidates(extension.to_str().unwrap()).is_ok() { - container.push( - sign_identity::SignIdentity::new( - self.file_type.clone(), - en.path().to_path_buf(), - self.key_type.clone(), - self.key_name.clone(), - self.get_sign_options())); + container.push(sign_identity::SignIdentity::new( + self.file_type.clone(), + en.path().to_path_buf(), + self.key_type.clone(), + self.key_name.clone(), + self.get_sign_options(), + )); } } - }, - Err(err)=> { + } + Err(err) => { error!("failed to scan file {}, will be skipped", err); } } } return Ok(container); } else if self.file_candidates(self.path.extension().unwrap().to_str().unwrap())? { - return Ok(vec![sign_identity::SignIdentity::new( - self.file_type.clone(), self.path.clone(), self.key_type.clone(), self.key_name.clone(), self.get_sign_options())]); + return Ok(vec![sign_identity::SignIdentity::new( + self.file_type.clone(), + self.path.clone(), + self.key_type.clone(), + self.key_name.clone(), + self.get_sign_options(), + )]); } Err(error::Error::NoFileCandidateError) } fn file_candidates(&self, extension: &str) -> Result { - let collections = FILE_EXTENSION.get( - &self.file_type).ok_or_else(|| - error::Error::FileNotSupportError(extension.to_string(), self.file_type.to_string()))?; + let collections = FILE_EXTENSION.get(&self.file_type).ok_or_else(|| { + error::Error::FileNotSupportError(extension.to_string(), self.file_type.to_string()) + })?; for value in collections { let re = Regex::new(format!(r"^{}$", value).as_str()).unwrap(); if re.is_match(extension) { - return Ok(true) + return Ok(true); } } - Err(error::Error::FileNotSupportError(extension.to_string(), self.file_type.to_string())) + Err(error::Error::FileNotSupportError( + extension.to_string(), + self.file_type.to_string(), + )) } } - impl SignCommand for CommandAddHandler { type CommandValue = CommandAdd; - fn new(signal: Arc, config: Arc>, command: Self::CommandValue) -> Result { + fn new( + signal: Arc, + config: Arc>, + command: Self::CommandValue, + ) -> Result { let mut worker_threads = config.read()?.get_string("worker_threads")?.parse()?; if worker_threads == 0 { worker_threads = num_cpus::get(); } let working_dir = config.read()?.get_string("working_dir")?; if !file_exists(&working_dir) { - return Err(error::Error::FileFoundError(format!("working dir: {} not exists", working_dir))); + return Err(error::Error::FileFoundError(format!( + "working dir: {} not exists", + 
working_dir + ))); } let mut token = None; if let Ok(t) = config.read()?.get_string("token") { @@ -173,7 +195,7 @@ impl SignCommand for CommandAddHandler { token = Some(t); } } - Ok(CommandAddHandler{ + Ok(CommandAddHandler { worker_threads, buffer_size: config.read()?.get_string("buffer_size")?.parse()?, working_dir: config.read()?.get_string("working_dir")?, @@ -187,7 +209,7 @@ impl SignCommand for CommandAddHandler { max_concurrency: config.read()?.get_string("max_concurrency")?.parse()?, sign_type: command.sign_type, token, - rpm_v3: command.rpm_v3 + rpm_v3: command.rpm_v3, }) } @@ -212,7 +234,8 @@ impl SignCommand for CommandAddHandler { .worker_threads(self.worker_threads) .enable_io() .enable_time() - .build().unwrap(); + .build() + .unwrap(); let (split_s, split_r) = bounded::(self.max_concurrency); let (sign_s, sign_r) = bounded::(self.max_concurrency); let (assemble_s, assemble_r) = bounded::(self.max_concurrency); @@ -221,49 +244,48 @@ impl SignCommand for CommandAddHandler { let errored = runtime.block_on(async { let channel_provider = ChannelFactory::new(&lb_config).await; if let Err(err) = channel_provider { - return Some(err) + return Some(err); } let channel = channel_provider.unwrap().get_channel(); if let Err(err) = channel { - return Some(err) + return Some(err); } //fetch datakey attributes - info!("starting to fetch datakey [{}] {} attribute",self.key_type, self.key_name); + info!( + "starting to fetch datakey [{}] {} attribute", + self.key_type, self.key_name + ); let mut key_fetcher = KeyFetcher::new(channel.clone().unwrap(), self.token.clone()); - let key_attributes; - match key_fetcher.get_key_attributes(&self.key_name, &self.key_type.to_string()).await { - Ok(attributes) => { - key_attributes = attributes - } - Err(err) => { - return Some(err) - } - } + let key_attributes = match key_fetcher + .get_key_attributes(&self.key_name, &self.key_type.to_string()) + .await + { + Ok(attributes) => attributes, + Err(err) => return Some(err), + }; //collect file candidates - let files; - match self.collect_file_candidates() { - Ok(f) => { - files = f - } - Err(err) => { - return Some(err) - } - } + let files = match self.collect_file_candidates() { + Ok(f) => f, + Err(err) => return Some(err), + }; info!("starting to sign {} files", files.len()); - let mut signer = RemoteSigner::new(channel.unwrap(), self.buffer_size, self.token.clone()); + let mut signer = + RemoteSigner::new(channel.unwrap(), self.buffer_size, self.token.clone()); //split file - let send_handlers = files.into_iter().map(|file|{ - let task_split_s = split_s.clone(); - tokio::spawn(async move { - let file_name = format!("{}", file.file_path.as_path().display()); - if let Err(err) = task_split_s.send(file).await { - error!("failed to send file for splitting: {}", err); - } else { - info!("starting to split file: {}", file_name); - } - + let send_handlers = files + .into_iter() + .map(|file| { + let task_split_s = split_s.clone(); + tokio::spawn(async move { + let file_name = format!("{}", file.file_path.as_path().display()); + if let Err(err) = task_split_s.send(file).await { + error!("failed to send file for splitting: {}", err); + } else { + info!("starting to split file: {}", file_name); + } + }) }) - }).collect::>(); + .collect::>(); //do file split let task_sign_s = sign_s.clone(); let s_key_attributes = key_attributes.clone(); @@ -274,10 +296,10 @@ impl SignCommand for CommandAddHandler { Ok(identity) => { let mut splitter = Splitter::new(s_key_attributes.clone()); splitter.handle(identity, 
task_sign_s.clone()).await; - }, + } Err(_) => { info!("split channel closed"); - return + return; } } } @@ -290,10 +312,10 @@ impl SignCommand for CommandAddHandler { match sign_identity { Ok(identity) => { signer.handle(identity, task_assemble_s.clone()).await; - }, + } Err(_) => { info!("sign channel closed"); - return + return; } } } @@ -306,12 +328,13 @@ impl SignCommand for CommandAddHandler { let sign_identity = assemble_r.recv().await; match sign_identity { Ok(identity) => { - let mut assembler = Assembler::new( working_dir.clone(), key_attributes.clone()); + let mut assembler = + Assembler::new(working_dir.clone(), key_attributes.clone()); assembler.handle(identity, task_collect_s.clone()).await; - }, + } Err(_) => { info!("assemble channel closed"); - return + return; } } } @@ -325,18 +348,23 @@ impl SignCommand for CommandAddHandler { match sign_identity { Ok(identity) => { if identity.error.borrow().clone().is_err() { - error!("failed to sign file {} due to error {:?}", + error!( + "failed to sign file {} due to error {:?}", identity.file_path.as_path().display(), - identity.error.borrow().clone().err()); - failed_files_c.fetch_add( 1, Ordering::SeqCst); + identity.error.borrow().clone().err() + ); + failed_files_c.fetch_add(1, Ordering::SeqCst); } else { - info!("successfully signed file {}", identity.file_path.as_path().display()); - succeed_files_c.fetch_add( 1, Ordering::SeqCst); + info!( + "successfully signed file {}", + identity.file_path.as_path().display() + ); + succeed_files_c.fetch_add(1, Ordering::SeqCst); } - }, + } Err(_) => { info!("collect channel closed"); - return + return; } } } @@ -344,7 +372,10 @@ impl SignCommand for CommandAddHandler { // wait for finish for h in send_handlers { if let Err(error) = h.await { - return Some(CommandProcessFailed(format!("failed to wait for send handler: {}", error.to_string()))) + return Some(CommandProcessFailed(format!( + "failed to wait for send handler: {}", + error.to_string() + ))); } } for (key, channel, worker) in [ @@ -355,19 +386,26 @@ impl SignCommand for CommandAddHandler { ] { drop(channel); if let Err(error) = worker.await { - return Some(CommandProcessFailed(format!("failed to wait for: {0} handler to finish: {1}", key, error.to_string()))) + return Some(CommandProcessFailed(format!( + "failed to wait for: {0} handler to finish: {1}", + key, + error.to_string() + ))); } } - info!("Successfully signed {} files failed {} files", - succeed_files.load(Ordering::Relaxed), failed_files.load(Ordering::Relaxed)); + info!( + "Successfully signed {} files failed {} files", + succeed_files.load(Ordering::Relaxed), + failed_files.load(Ordering::Relaxed) + ); info!("sign files process finished"); None }); if let Some(err) = errored { - return Err(err) + return Err(err); } if failed_files.load(Ordering::Relaxed) != 0 { - return Ok(false) + return Ok(false); } Ok(true) } diff --git a/src/client/cmd/mod.rs b/src/client/cmd/mod.rs index dbbd277..986adfb 100644 --- a/src/client/cmd/mod.rs +++ b/src/client/cmd/mod.rs @@ -1,2 +1,2 @@ pub mod add; -pub mod traits; \ No newline at end of file +pub mod traits; diff --git a/src/client/cmd/traits.rs b/src/client/cmd/traits.rs index a88ddd6..53f3e27 100644 --- a/src/client/cmd/traits.rs +++ b/src/client/cmd/traits.rs @@ -15,13 +15,16 @@ */ use crate::util::error::Result; -use std::sync::{RwLock, Arc, atomic::AtomicBool}; use config::Config; - +use std::sync::{atomic::AtomicBool, Arc, RwLock}; pub trait SignCommand: Clone { type CommandValue; - fn new(signal: Arc, config: Arc>, 
command: Self::CommandValue) -> Result; + fn new( + signal: Arc, + config: Arc>, + command: Self::CommandValue, + ) -> Result; fn validate(&self) -> Result<()>; fn handle(&self) -> Result; -} \ No newline at end of file +} diff --git a/src/client/file_handler/efi.rs b/src/client/file_handler/efi.rs index 167cbdc..cb1c17f 100644 --- a/src/client/file_handler/efi.rs +++ b/src/client/file_handler/efi.rs @@ -1,14 +1,14 @@ use super::traits::FileHandler; -use crate::util::options; -use crate::util::sign::{SignType, KeyType}; use crate::util::error::{Error, Result}; +use crate::util::options; +use crate::util::sign::{KeyType, SignType}; use async_trait::async_trait; use efi_signer::{DigestAlgorithm, EfiImage}; use std::collections::HashMap; use std::fs::read; +use std::io::Write; use std::path::PathBuf; use uuid::Uuid; -use std::io::Write; pub struct EfiFileHandler {} impl EfiFileHandler { @@ -50,7 +50,7 @@ impl FileHandler for EfiFileHandler { &self, path: &PathBuf, _sign_options: &mut HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) -> Result>> { let buf = read(path)?; let pe = EfiImage::parse(&buf)?; @@ -58,7 +58,11 @@ impl FileHandler for EfiFileHandler { Some(algo) => pe.compute_digest(algo)?, None => pe.compute_digest(DigestAlgorithm::Sha256)?, }; - info!("file {} digest {:x?}", path.as_path().display().to_string(), digest.as_slice()); + info!( + "file {} digest {:x?}", + path.as_path().display().to_string(), + digest.as_slice() + ); Ok(vec![digest]) } @@ -68,13 +72,13 @@ impl FileHandler for EfiFileHandler { data: Vec>, temp_dir: &PathBuf, _sign_options: &HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) -> Result<(String, String)> { let temp_file = temp_dir.join(Uuid::new_v4().to_string()); let buf = read(path)?; let pe = EfiImage::parse(&buf)?; - let mut signatures :Vec = Vec::new(); + let mut signatures: Vec = Vec::new(); for d in data.iter() { signatures.push(efi_signer::Signature::decode(d)?); @@ -164,7 +168,13 @@ mod test { let temp_dir = env::temp_dir(); let result = handler - .assemble_data(&path, vec![signature_buf], &temp_dir, &options, &HashMap::new()) + .assemble_data( + &path, + vec![signature_buf], + &temp_dir, + &options, + &HashMap::new(), + ) .await; assert!(result.is_ok()); let (temp_file, file_name) = result.expect("efi sign should work"); diff --git a/src/client/file_handler/factory.rs b/src/client/file_handler/factory.rs index 5954b15..05e21b7 100644 --- a/src/client/file_handler/factory.rs +++ b/src/client/file_handler/factory.rs @@ -14,31 +14,22 @@ * */ -use super::rpm::RpmFileHandler; use super::efi::EfiFileHandler; use super::generic::GenericFileHandler; use super::kernel_module::KernelModuleFileHandler; -use crate::util::sign::FileType; +use super::rpm::RpmFileHandler; use super::traits::FileHandler; +use crate::util::sign::FileType; -pub struct FileHandlerFactory { -} +pub struct FileHandlerFactory {} impl FileHandlerFactory { pub fn get_handler(file_type: &FileType) -> Box { match file_type { - FileType::Rpm => { - Box::new(RpmFileHandler::new()) - }, - FileType::Generic => { - Box::new(GenericFileHandler::new()) - }, - FileType::KernelModule => { - Box::new(KernelModuleFileHandler::new()) - }, - FileType::EfiImage => { - Box::new(EfiFileHandler::new()) - }, + FileType::Rpm => Box::new(RpmFileHandler::new()), + FileType::Generic => Box::new(GenericFileHandler::new()), + FileType::KernelModule => Box::new(KernelModuleFileHandler::new()), + FileType::EfiImage => Box::new(EfiFileHandler::new()), } } -} \ No newline at end of 
file +} diff --git a/src/client/file_handler/generic.rs b/src/client/file_handler/generic.rs index bafb962..5df68cd 100644 --- a/src/client/file_handler/generic.rs +++ b/src/client/file_handler/generic.rs @@ -15,15 +15,15 @@ */ use super::traits::FileHandler; -use crate::util::sign::{KeyType}; use crate::util::error::Result; +use crate::util::sign::KeyType; use async_trait::async_trait; use std::path::PathBuf; use tokio::fs; use uuid::Uuid; -use crate::util::options; use crate::util::error::Error; +use crate::util::options; use std::collections::HashMap; //Stands for ASCII Armored file @@ -66,7 +66,7 @@ impl FileHandler for GenericFileHandler { data: Vec>, temp_dir: &PathBuf, _sign_options: &HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) -> Result<(String, String)> { let temp_file = temp_dir.join(Uuid::new_v4().to_string()); //convert bytes into string @@ -121,7 +121,9 @@ mod test { let path = PathBuf::from("./test_data/test.txt"); let data = vec![vec![1, 2, 3]]; let temp_dir = env::temp_dir(); - let result = handler.assemble_data(&path, data, &temp_dir, &options, &HashMap::new()).await; + let result = handler + .assemble_data(&path, data, &temp_dir, &options, &HashMap::new()) + .await; assert!(result.is_ok()); let (temp_file, file_name) = result.expect("invoke assemble data should work"); assert_eq!(temp_file.starts_with(temp_dir.to_str().unwrap()), true); diff --git a/src/client/file_handler/kernel_module.rs b/src/client/file_handler/kernel_module.rs index 21a2092..ce43700 100644 --- a/src/client/file_handler/kernel_module.rs +++ b/src/client/file_handler/kernel_module.rs @@ -28,10 +28,10 @@ use std::io::{Read, Seek, Write}; use std::os::raw::{c_uchar, c_uint}; use uuid::Uuid; -use crate::util::options; -use crate::util::sign::{SignType, KeyType}; use crate::util::error::Error; +use crate::util::options; use crate::util::options::DETACHED; +use crate::util::sign::{KeyType, SignType}; const FILE_EXTENSION: &str = "p7s"; const PKEY_ID_PKCS7: c_uchar = 2; @@ -99,7 +99,11 @@ impl KernelModuleFileHandler { Ok(()) } - pub fn get_raw_content(&self, path: &PathBuf, sign_options: &mut HashMap) -> Result> { + pub fn get_raw_content( + &self, + path: &PathBuf, + sign_options: &mut HashMap, + ) -> Result> { let raw_content = fs::read(path)?; let mut file = fs::File::open(path)?; if file.metadata()?.len() <= SIGNATURE_SIZE as u64 { @@ -113,7 +117,8 @@ impl KernelModuleFileHandler { Ok(ending) => { return if ending == MAGIC_NUMBER { file.seek(io::SeekFrom::End(-(SIGNATURE_SIZE as i64)))?; - let mut signature_meta: [u8; SIGNATURE_SIZE - MAGIC_NUMBER_SIZE] = [0; SIGNATURE_SIZE - MAGIC_NUMBER_SIZE]; + let mut signature_meta: [u8; SIGNATURE_SIZE - MAGIC_NUMBER_SIZE] = + [0; SIGNATURE_SIZE - MAGIC_NUMBER_SIZE]; let _ = file.read(&mut signature_meta)?; //decode kernel module signature struct let signature: ModuleSignature = bincode::decode_from_slice( @@ -162,8 +167,9 @@ impl FileHandler for KernelModuleFileHandler { } if let Some(sign_type) = sign_options.get(options::SIGN_TYPE) { - if sign_type != SignType::Cms.to_string().as_str() && - sign_type != SignType::PKCS7.to_string().as_str() { + if sign_type != SignType::Cms.to_string().as_str() + && sign_type != SignType::PKCS7.to_string().as_str() + { return Err(Error::InvalidArgumentError( "kernel module file only support cms or pkcs7 sign type".to_string(), )); @@ -177,7 +183,7 @@ impl FileHandler for KernelModuleFileHandler { &self, path: &PathBuf, sign_options: &mut HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) 
-> Result>> { Ok(vec![self.get_raw_content(path, sign_options)?]) } @@ -189,7 +195,7 @@ impl FileHandler for KernelModuleFileHandler { data: Vec>, temp_dir: &PathBuf, sign_options: &HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) -> Result<(String, String)> { let temp_file = temp_dir.join(Uuid::new_v4().to_string()); //convert bytes into string @@ -213,17 +219,20 @@ impl FileHandler for KernelModuleFileHandler { #[cfg(test)] mod test { use super::*; - use std::env; use rand::Rng; + use std::env; - fn generate_signed_kernel_module(length: usize, incorrect_length: bool) -> Result<(String, Vec)> { + fn generate_signed_kernel_module( + length: usize, + incorrect_length: bool, + ) -> Result<(String, Vec)> { let mut rng = rand::thread_rng(); let temp_file = env::temp_dir().join(Uuid::new_v4().to_string()); let mut file = fs::File::create(temp_file.clone())?; let raw_content: Vec = (0..length).map(|_| rng.gen_range(0..=255)).collect(); file.write_all(&raw_content)?; //append fake signature - let signature = vec![1,2,3,4,5,6]; + let signature = vec![1, 2, 3, 4, 5, 6]; file.write_all(&signature)?; let mut size = signature.len(); if incorrect_length { @@ -235,7 +244,8 @@ mod test { signature, config::standard() .with_fixed_int_encoding() - .with_big_endian())?)?; + .with_big_endian(), + )?)?; file.write_all(MAGIC_NUMBER.as_bytes())?; Ok((temp_file.display().to_string(), raw_content)) } @@ -253,10 +263,13 @@ mod test { fn test_get_raw_content_with_small_unsigned_content() { let mut sign_options = HashMap::new(); let file_handler = KernelModuleFileHandler::new(); - let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE-1).expect("generate unsigned kernel module failed"); + let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE - 1) + .expect("generate unsigned kernel module failed"); let path = PathBuf::from(name); - let raw_content = file_handler.get_raw_content(&path, &mut sign_options).expect("get raw content failed"); - assert_eq!(raw_content.len(), SIGNATURE_SIZE-1); + let raw_content = file_handler + .get_raw_content(&path, &mut sign_options) + .expect("get raw content failed"); + assert_eq!(raw_content.len(), SIGNATURE_SIZE - 1); assert_eq!(original_content, raw_content); } @@ -264,10 +277,13 @@ mod test { fn test_get_raw_content_with_large_unsigned_content() { let mut sign_options = HashMap::new(); let file_handler = KernelModuleFileHandler::new(); - let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE+100).expect("generate unsigned kernel module failed"); + let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE + 100) + .expect("generate unsigned kernel module failed"); let path = PathBuf::from(name); - let raw_content = file_handler.get_raw_content(&path, &mut sign_options).expect("get raw content failed"); - assert_eq!(raw_content.len(), SIGNATURE_SIZE+100); + let raw_content = file_handler + .get_raw_content(&path, &mut sign_options) + .expect("get raw content failed"); + assert_eq!(raw_content.len(), SIGNATURE_SIZE + 100); assert_eq!(original_content, raw_content); } @@ -275,9 +291,12 @@ mod test { fn test_get_raw_content_with_signed_content() { let mut sign_options = HashMap::new(); let file_handler = KernelModuleFileHandler::new(); - let (name, original_content) = generate_signed_kernel_module(100,false).expect("generate signed kernel module failed"); + let (name, original_content) = generate_signed_kernel_module(100, false) + .expect("generate signed kernel module failed"); 
let path = PathBuf::from(name); - let raw_content = file_handler.get_raw_content(&path, &mut sign_options).expect("get raw content failed"); + let raw_content = file_handler + .get_raw_content(&path, &mut sign_options) + .expect("get raw content failed"); assert_eq!(raw_content.len(), 100); assert_eq!(original_content, raw_content); } @@ -286,7 +305,8 @@ mod test { fn test_get_raw_content_with_invalid_signed_content() { let mut sign_options = HashMap::new(); let file_handler = KernelModuleFileHandler::new(); - let (name, _) = generate_signed_kernel_module(100,true).expect("generate signed kernel module failed"); + let (name, _) = + generate_signed_kernel_module(100, true).expect("generate signed kernel module failed"); let path = PathBuf::from(name); let result = file_handler.get_raw_content(&path, &mut sign_options); assert_eq!( @@ -308,7 +328,10 @@ mod test { ); options.insert(options::KEY_TYPE.to_string(), KeyType::X509EE.to_string()); - options.insert(options::SIGN_TYPE.to_string(), SignType::Authenticode.to_string()); + options.insert( + options::SIGN_TYPE.to_string(), + SignType::Authenticode.to_string(), + ); let result = handler.validate_options(&options); assert!(result.is_err()); assert_eq!( @@ -316,7 +339,6 @@ mod test { "invalid argument: kernel module file only support cms or pkcs7 sign type" ); - options.insert(options::SIGN_TYPE.to_string(), SignType::Cms.to_string()); let result = handler.validate_options(&options); assert!(result.is_ok()); @@ -334,7 +356,9 @@ mod test { let path = PathBuf::from("./test_data/test.ko"); let data = vec![vec![1, 2, 3]]; let temp_dir = env::temp_dir(); - let result = handler.assemble_data(&path, data, &temp_dir, &options, &HashMap::new()).await; + let result = handler + .assemble_data(&path, data, &temp_dir, &options, &HashMap::new()) + .await; assert!(result.is_ok()); let (temp_file, file_name) = result.expect("invoke assemble data should work"); assert_eq!(temp_file.starts_with(temp_dir.to_str().unwrap()), true); @@ -348,16 +372,21 @@ mod test { let handler = KernelModuleFileHandler::new(); let mut options = HashMap::new(); options.insert(DETACHED.to_string(), "false".to_string()); - let (name, raw_content) = generate_signed_kernel_module(100,false).expect("generate signed kernel module failed"); + let (name, raw_content) = generate_signed_kernel_module(100, false) + .expect("generate signed kernel module failed"); let path = PathBuf::from(name.clone()); let data = vec![vec![1, 2, 3]]; let temp_dir = env::temp_dir(); - let result = handler.assemble_data(&path, data, &temp_dir, &options, &HashMap::new()).await; + let result = handler + .assemble_data(&path, data, &temp_dir, &options, &HashMap::new()) + .await; assert!(result.is_ok()); let (temp_file, file_name) = result.expect("invoke assemble data should work"); assert_eq!(temp_file.starts_with(temp_dir.to_str().unwrap()), true); assert_eq!(file_name, name); - let result = handler.get_raw_content(&PathBuf::from(temp_file), &mut options).expect("get raw content failed"); + let result = handler + .get_raw_content(&PathBuf::from(temp_file), &mut options) + .expect("get raw content failed"); assert_eq!(result, raw_content); } @@ -365,12 +394,14 @@ mod test { async fn test_split_content() { let mut sign_options = HashMap::new(); let file_handler = KernelModuleFileHandler::new(); - let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE-1).expect("generate unsigned kernel module failed"); + let (name, original_content) = generate_unsigned_kernel_module(SIGNATURE_SIZE - 1) + 
.expect("generate unsigned kernel module failed"); let path = PathBuf::from(name); - let raw_content = file_handler.split_data(&path, &mut sign_options, &HashMap::new()).await.expect("get raw content failed"); - assert_eq!(raw_content[0].len(), SIGNATURE_SIZE-1); + let raw_content = file_handler + .split_data(&path, &mut sign_options, &HashMap::new()) + .await + .expect("get raw content failed"); + assert_eq!(raw_content[0].len(), SIGNATURE_SIZE - 1); assert_eq!(original_content, raw_content[0]); } - } - diff --git a/src/client/file_handler/mod.rs b/src/client/file_handler/mod.rs index 6b2a569..0ce9995 100644 --- a/src/client/file_handler/mod.rs +++ b/src/client/file_handler/mod.rs @@ -1,6 +1,6 @@ -pub mod rpm; pub mod efi; -pub mod traits; pub mod factory; pub mod generic; -pub mod kernel_module; \ No newline at end of file +pub mod kernel_module; +pub mod rpm; +pub mod traits; diff --git a/src/client/file_handler/rpm.rs b/src/client/file_handler/rpm.rs index d2f6332..16fcc21 100644 --- a/src/client/file_handler/rpm.rs +++ b/src/client/file_handler/rpm.rs @@ -14,61 +14,60 @@ * */ -use std::collections::HashMap; -use std::path::PathBuf; use super::traits::FileHandler; -use async_trait::async_trait; use crate::util::error::Result; +use async_trait::async_trait; +use rpm::{Digests, Header, IndexSignatureTag, IndexTag, Package}; +use std::collections::HashMap; use std::fs::File; -use std::str::FromStr; #[allow(unused_imports)] use std::io::{BufReader, Read}; -use rpm::{Header, IndexSignatureTag, Package, Digests, IndexTag}; +use std::path::PathBuf; +use std::str::FromStr; -use uuid::Uuid; use crate::domain::datakey::plugins::openpgp::OpenPGPKeyType; -use crate::util::{options}; -use crate::util::sign::KeyType; use crate::util::error::Error; +use crate::util::options; +use crate::util::sign::KeyType; +use uuid::Uuid; #[derive(Clone)] -pub struct RpmFileHandler { - -} - +pub struct RpmFileHandler {} impl RpmFileHandler { pub fn new() -> Self { - Self { - - } + Self {} } //defaults to RSA fn get_key_type(&self, key_attributes: &HashMap) -> OpenPGPKeyType { match key_attributes.get(options::KEY_TYPE) { - Some(value) => { - match OpenPGPKeyType::from_str(value) { - Ok(key_type) => { - key_type - } - Err(_) => { - OpenPGPKeyType::Rsa - } - } - } - None => { - OpenPGPKeyType::Rsa - } + Some(value) => match OpenPGPKeyType::from_str(value) { + Ok(key_type) => key_type, + Err(_) => OpenPGPKeyType::Rsa, + }, + None => OpenPGPKeyType::Rsa, } } - fn generate_v3_signature(&self, sign_options: &HashMap, package: &Package) -> bool { + fn generate_v3_signature( + &self, + sign_options: &HashMap, + package: &Package, + ) -> bool { if let Some(v3_format) = sign_options.get(options::RPM_V3_SIGNATURE) { if v3_format == "true" { return true; } } - if package.metadata.header.entry_is_present(IndexTag::RPMTAG_PAYLOADDIGEST) || package.metadata.header.entry_is_present(IndexTag::RPMTAG_PAYLOADDIGESTALT) { + if package + .metadata + .header + .entry_is_present(IndexTag::RPMTAG_PAYLOADDIGEST) + || package + .metadata + .header + .entry_is_present(IndexTag::RPMTAG_PAYLOADDIGESTALT) + { return false; } true @@ -78,16 +77,19 @@ impl RpmFileHandler { //todo: figure our why is much slower when async read & write with tokio is enabled. 
#[async_trait] impl FileHandler for RpmFileHandler { - fn validate_options(&self, sign_options: &HashMap) -> Result<()> { if let Some(detached) = sign_options.get(options::DETACHED) { if detached == "true" { - return Err(Error::InvalidArgumentError("rpm file only support inside signature".to_string())) + return Err(Error::InvalidArgumentError( + "rpm file only support inside signature".to_string(), + )); } } if let Some(key_type) = sign_options.get(options::KEY_TYPE) { if key_type != KeyType::Pgp.to_string().as_str() { - return Err(Error::InvalidArgumentError("rpm file only support pgp signature".to_string())) + return Err(Error::InvalidArgumentError( + "rpm file only support pgp signature".to_string(), + )); } } Ok(()) @@ -102,7 +104,8 @@ impl FileHandler for RpmFileHandler { &self, path: &PathBuf, sign_options: &mut HashMap, - key_attributes: &HashMap) -> Result>> { + key_attributes: &HashMap, + ) -> Result>> { let file = File::open(path)?; let package = Package::parse(&mut BufReader::new(file))?; let mut header_bytes = Vec::::with_capacity(1024); @@ -110,28 +113,28 @@ impl FileHandler for RpmFileHandler { package.metadata.header.write(&mut header_bytes)?; return if self.get_key_type(key_attributes) == OpenPGPKeyType::Eddsa { if self.generate_v3_signature(sign_options, &package) { - Err(Error::InvalidArgumentError("eddsa key does not support v3 signature".to_string())) + Err(Error::InvalidArgumentError( + "eddsa key does not support v3 signature".to_string(), + )) } else { Ok(vec![header_bytes]) } + } else if self.generate_v3_signature(sign_options, &package) { + let mut header_and_content = Vec::new(); + header_and_content.extend(header_bytes.clone()); + header_and_content.extend(package.content.clone()); + Ok(vec![header_bytes, header_and_content]) } else { - if self.generate_v3_signature(sign_options, &package) { - let mut header_and_content = Vec::new(); - header_and_content.extend(header_bytes.clone()); - header_and_content.extend(package.content.clone()); - Ok(vec![header_bytes, header_and_content]) - } else { - Ok(vec![header_bytes]) - } - } + Ok(vec![header_bytes]) + }; } async fn assemble_data( &self, - path: & PathBuf, + path: &PathBuf, data: Vec>, temp_dir: &PathBuf, sign_options: &HashMap, - key_attributes: &HashMap + key_attributes: &HashMap, ) -> Result<(String, String)> { let temp_rpm = temp_dir.join(Uuid::new_v4().to_string()); let file = File::open(path)?; @@ -142,45 +145,49 @@ impl FileHandler for RpmFileHandler { header_digest_sha256, header_digest_sha1, header_and_content_digest, - } = Package::create_sig_header_digests(header_bytes.as_slice(), &package.content.as_slice())?; + } = Package::create_sig_header_digests( + header_bytes.as_slice(), + &package.content.as_slice(), + )?; let key_type = self.get_key_type(key_attributes); let builder = match key_type { OpenPGPKeyType::Rsa => { if self.generate_v3_signature(sign_options, &package) { - Header::::builder().add_digest( - &header_digest_sha1, - &header_digest_sha256, - &header_and_content_digest, - ).add_rsa_signature_legacy( - data[0].as_slice(), - data[1].as_slice() - ) + Header::::builder() + .add_digest( + &header_digest_sha1, + &header_digest_sha256, + &header_and_content_digest, + ) + .add_rsa_signature_legacy(data[0].as_slice(), data[1].as_slice()) } else { - Header::::builder().add_digest( - &header_digest_sha1, - &header_digest_sha256, - &header_and_content_digest, - ).add_rsa_signature( - data[0].as_slice(), - ) + Header::::builder() + .add_digest( + &header_digest_sha1, + &header_digest_sha256, + 
&header_and_content_digest, + ) + .add_rsa_signature(data[0].as_slice()) } } - OpenPGPKeyType::Eddsa => { - Header::::builder().add_digest( + OpenPGPKeyType::Eddsa => Header::::builder() + .add_digest( &header_digest_sha1, &header_digest_sha256, &header_and_content_digest, - ).add_eddsa_signature( - data[0].as_slice(), ) - } + .add_eddsa_signature(data[0].as_slice()), }; - package.metadata.signature = builder.build(header_bytes.as_slice().len() + package.content.len()); + package.metadata.signature = + builder.build(header_bytes.as_slice().len() + package.content.len()); //save data into temp file let mut output = File::create(temp_rpm.clone())?; package.write(&mut output)?; - Ok((temp_rpm.as_path().display().to_string(), format!("{}", path.display()))) + Ok(( + temp_rpm.as_path().display().to_string(), + format!("{}", path.display()), + )) } } @@ -193,18 +200,26 @@ mod test { fn get_signed_rpm() -> Result { let current_dir = env::current_dir().expect("get current dir failed"); - Ok(PathBuf::from(current_dir.join("test_assets").join("Imath-3.1.4-1.oe2303.x86_64.rpm"))) + Ok(PathBuf::from( + current_dir + .join("test_assets") + .join("Imath-3.1.4-1.oe2303.x86_64.rpm"), + )) } fn get_signed_rpm_without_payload_digest() -> Result { let current_dir = env::current_dir().expect("get current dir failed"); - Ok(PathBuf::from(current_dir.join("test_assets").join("at-3.1.10-43.el6.x86_64.rpm"))) + Ok(PathBuf::from( + current_dir + .join("test_assets") + .join("at-3.1.10-43.el6.x86_64.rpm"), + )) } fn generate_invalid_rpm() -> Result { let temp_file = env::temp_dir().join(Uuid::new_v4().to_string()); let mut file = File::create(temp_file.clone())?; - let content = vec![1,2,3,4]; + let content = vec![1, 2, 3, 4]; file.write_all(&content)?; Ok(temp_file) } @@ -251,7 +266,10 @@ mod test { let mut sign_options = HashMap::new(); let file_handler = RpmFileHandler::new(); let path = get_signed_rpm().expect("get signed rpm failed"); - let raw_content = file_handler.split_data(&path, &mut sign_options, &HashMap::new()).await.expect("get raw content failed"); + let raw_content = file_handler + .split_data(&path, &mut sign_options, &HashMap::new()) + .await + .expect("get raw content failed"); assert_eq!(raw_content.len(), 1); assert_eq!(raw_content[0].len(), 4325); } @@ -262,7 +280,10 @@ mod test { sign_options.insert(options::RPM_V3_SIGNATURE.to_string(), true.to_string()); let file_handler = RpmFileHandler::new(); let path = get_signed_rpm().expect("get signed rpm failed"); - let raw_content = file_handler.split_data(&path, &mut sign_options, &HashMap::new()).await.expect("get raw content failed"); + let raw_content = file_handler + .split_data(&path, &mut sign_options, &HashMap::new()) + .await + .expect("get raw content failed"); assert_eq!(raw_content.len(), 2); assert_eq!(raw_content[0].len(), 4325); assert_eq!(raw_content[1].len(), 67757); @@ -273,10 +294,16 @@ mod test { let mut sign_options = HashMap::new(); sign_options.insert(options::RPM_V3_SIGNATURE.to_string(), true.to_string()); let mut key_attributes = HashMap::new(); - key_attributes.insert(options::KEY_TYPE.to_string(), OpenPGPKeyType::Eddsa.to_string()); + key_attributes.insert( + options::KEY_TYPE.to_string(), + OpenPGPKeyType::Eddsa.to_string(), + ); let file_handler = RpmFileHandler::new(); let path = get_signed_rpm().expect("get signed rpm failed"); - file_handler.split_data(&path, &mut sign_options, &key_attributes).await.expect_err("eddsa key does not support v3 signature"); + file_handler + .split_data(&path, &mut sign_options, 
&key_attributes) + .await + .expect_err("eddsa key does not support v3 signature"); } #[tokio::test] @@ -284,7 +311,10 @@ mod test { let mut sign_options = HashMap::new(); let file_handler = RpmFileHandler::new(); let path = get_signed_rpm_without_payload_digest().expect("get signed rpm failed"); - let raw_content = file_handler.split_data(&path, &mut sign_options, &HashMap::new()).await.expect("get raw content failed"); + let raw_content = file_handler + .split_data(&path, &mut sign_options, &HashMap::new()) + .await + .expect("get raw content failed"); assert_eq!(raw_content.len(), 2); assert_eq!(raw_content[0].len(), 21484); assert_eq!(raw_content[1].len(), 60304); @@ -295,7 +325,10 @@ mod test { let mut sign_options = HashMap::new(); let file_handler = RpmFileHandler::new(); let path = generate_invalid_rpm().expect("generate invalid rpm failed"); - let _raw_content = file_handler.split_data(&path, &mut sign_options, &HashMap::new()).await.expect_err("split invalid rpm file would failed"); + let _raw_content = file_handler + .split_data(&path, &mut sign_options, &HashMap::new()) + .await + .expect_err("split invalid rpm file would failed"); } #[tokio::test] @@ -303,11 +336,16 @@ mod test { let mut sign_options = HashMap::new(); let file_handler = RpmFileHandler::new(); let path = generate_signed_rpm().expect("generate signed rpm failed"); - let fake_signature = vec![vec![1,2,3,4], vec![1,2,3,4]]; - let _raw_content = file_handler.assemble_data(&path, fake_signature, &env::temp_dir(), &mut sign_options, &HashMap::new()).await.expect("assemble data failed"); + let fake_signature = vec![vec![1, 2, 3, 4], vec![1, 2, 3, 4]]; + let _raw_content = file_handler + .assemble_data( + &path, + fake_signature, + &env::temp_dir(), + &mut sign_options, + &HashMap::new(), + ) + .await + .expect("assemble data failed"); } - } - - - diff --git a/src/client/file_handler/traits.rs b/src/client/file_handler/traits.rs index cb50995..b1848f7 100644 --- a/src/client/file_handler/traits.rs +++ b/src/client/file_handler/traits.rs @@ -27,7 +27,7 @@ pub trait FileHandler: Send + Sync { &self, path: &PathBuf, _sign_options: &mut HashMap, - _key_attributes: &HashMap + _key_attributes: &HashMap, ) -> Result>> { let content = fs::read(path).await?; Ok(vec![content]) @@ -39,6 +39,6 @@ pub trait FileHandler: Send + Sync { data: Vec>, temp_dir: &PathBuf, sign_options: &HashMap, - key_attributes: &HashMap + key_attributes: &HashMap, ) -> Result<(String, String)>; } diff --git a/src/client/load_balancer/dns.rs b/src/client/load_balancer/dns.rs index 3fcca98..5ad5524 100644 --- a/src/client/load_balancer/dns.rs +++ b/src/client/load_balancer/dns.rs @@ -14,33 +14,34 @@ * */ -use tonic::transport::{Channel, ClientTlsConfig}; use super::traits::DynamicLoadBalancer; +use tonic::transport::{Channel, ClientTlsConfig}; use crate::util::error::Result; -use tonic::transport::Endpoint; use async_trait::async_trait; +use tonic::transport::Endpoint; - -use dns_lookup::{lookup_host}; -use crate::util::error::Error::{DNSResolveError}; +use crate::util::error::Error::DNSResolveError; +use dns_lookup::lookup_host; pub struct DNSLoadBalancer { hostname: String, port: String, - client_config: Option + client_config: Option, } impl DNSLoadBalancer { - - pub fn new(hostname: String, port: String, client_config: Option) -> Result { + pub fn new( + hostname: String, + port: String, + client_config: Option, + ) -> Result { Ok(Self { hostname, port, - client_config + client_config, }) } - } #[async_trait] @@ -50,8 +51,8 @@ impl DynamicLoadBalancer 
for DNSLoadBalancer {
         match lookup_host(&self.hostname) {
             Ok(hosts) => {
                 for ip in hosts.into_iter() {
-                    let mut endpoint = Endpoint::from_shared(
-                        format!("http://{}:{}", ip, self.port))?;
+                    let mut endpoint =
+                        Endpoint::from_shared(format!("http://{}:{}", ip, self.port))?;
                     if let Some(tls_config) = self.client_config.clone() {
                         endpoint = endpoint.tls_config(tls_config)?;
                     }
@@ -60,10 +61,7 @@ impl DynamicLoadBalancer for DNSLoadBalancer {
                 }
                 Ok(Channel::balance_list(endpoints.into_iter()))
             }
-            Err(_) => {
-                Err(DNSResolveError(self.hostname.clone()))
-            }
+            Err(_) => Err(DNSResolveError(self.hostname.clone())),
         }
-
     }
-}
\ No newline at end of file
+}
diff --git a/src/client/load_balancer/factory.rs b/src/client/load_balancer/factory.rs
index 79ed045..ad24c6a 100644
--- a/src/client/load_balancer/factory.rs
+++ b/src/client/load_balancer/factory.rs
@@ -14,67 +14,106 @@
  *
  */
 
-use tonic::transport::{Channel, ClientTlsConfig, Identity};
-use std::collections::HashMap;
-use config::Value;
 use crate::client::load_balancer::dns::DNSLoadBalancer;
 use crate::client::load_balancer::single::SingleLoadBalancer;
 use crate::client::load_balancer::traits::DynamicLoadBalancer;
-use crate::util::error::{Error, Result};
 use crate::util::error::Error::ConfigError;
+use crate::util::error::{Error, Result};
 use crate::util::key::file_exists;
+use config::Value;
+use std::collections::HashMap;
+use tonic::transport::{Channel, ClientTlsConfig, Identity};
 
 pub struct ChannelFactory {
-    lb: Box<dyn DynamicLoadBalancer>
+    lb: Box<dyn DynamicLoadBalancer>,
 }
 
 impl ChannelFactory {
     pub async fn new(config: &HashMap<String, Value>) -> Result<Self> {
-        let mut client_config :Option<ClientTlsConfig> = None;
-        let tls_cert = config.get("tls_cert").unwrap_or(
-            &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string();
-        let tls_key = config.get("tls_key").unwrap_or(
-            &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string();
-        let server_address = config.get("server_address").unwrap_or(
-            &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string();
-        let server_port = config.get("server_port").unwrap_or(
-            &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string();
+        let mut client_config: Option<ClientTlsConfig> = None;
+        let tls_cert = config
+            .get("tls_cert")
+            .unwrap_or(&Value::new(
+                Some(&String::new()),
+                config::ValueKind::String(String::new()),
+            ))
+            .to_string();
+        let tls_key = config
+            .get("tls_key")
+            .unwrap_or(&Value::new(
+                Some(&String::new()),
+                config::ValueKind::String(String::new()),
+            ))
+            .to_string();
+        let server_address = config
+            .get("server_address")
+            .unwrap_or(&Value::new(
+                Some(&String::new()),
+                config::ValueKind::String(String::new()),
+            ))
+            .to_string();
+        let server_port = config
+            .get("server_port")
+            .unwrap_or(&Value::new(
+                Some(&String::new()),
+                config::ValueKind::String(String::new()),
+            ))
+            .to_string();
         if server_address.is_empty() || server_port.is_empty() {
-            return Err(ConfigError(format!("server address: {} or port: {} not configured", server_address, server_port)));
+            return Err(ConfigError(format!(
+                "server address: {} or port: {} not configured",
+                server_address, server_port
+            )));
         }
-        if tls_cert.is_empty() || tls_key.is_empty()
-        {
+        if tls_cert.is_empty() || tls_key.is_empty() {
             info!("tls client key and cert not configured, tls will be disabled");
         } else {
             info!("tls client key and cert configured, tls will be enabled");
             debug!("tls cert:{}, tls key:{}", tls_cert, tls_key);
-            if !file_exists(&tls_cert) || !file_exists(&tls_key){
-                return Err(Error::FileFoundError(format!("client tls cert {} or key {} file not found", tls_key, tls_cert)));
+            if !file_exists(&tls_cert) || !file_exists(&tls_key) {
+                return Err(Error::FileFoundError(format!(
+                    "client tls cert {} or key {} file not found",
+                    tls_key, tls_cert
+                )));
             }
             let identity = Identity::from_pem(
                 tokio::fs::read(tls_cert).await?,
-                tokio::fs::read(tls_key).await?);
-            client_config = Some(ClientTlsConfig::new()
-                .identity(identity).domain_name(config.get("domain_name").unwrap_or(&Value::default()).to_string()));
+                tokio::fs::read(tls_key).await?,
+            );
+            client_config = Some(
+                ClientTlsConfig::new().identity(identity).domain_name(
+                    config
+                        .get("domain_name")
+                        .unwrap_or(&Value::default())
+                        .to_string(),
+                ),
+            );
         }
         let lb_type = config.get("type").unwrap_or(&Value::default()).to_string();
         if lb_type == "single" {
             return Ok(Self {
                 lb: Box::new(SingleLoadBalancer::new(
                     server_address,
-                    server_port, client_config)?)
-            })
+                    server_port,
+                    client_config,
+                )?),
+            });
         } else if lb_type == "dns" {
             return Ok(Self {
                 lb: Box::new(DNSLoadBalancer::new(
                     server_address,
-                    server_port, client_config)?)
-            })
+                    server_port,
+                    client_config,
+                )?),
+            });
         }
-        Err(ConfigError(format!("invalid load balancer type configuration: {}", lb_type)))
+        Err(ConfigError(format!(
+            "invalid load balancer type configuration: {}",
+            lb_type
+        )))
     }
     pub fn get_channel(&self) -> Result<Channel> {
         self.lb.get_transport_channel()
     }
-}
\ No newline at end of file
+}
diff --git a/src/client/load_balancer/mod.rs b/src/client/load_balancer/mod.rs
index ef77cf5..6ca143e 100644
--- a/src/client/load_balancer/mod.rs
+++ b/src/client/load_balancer/mod.rs
@@ -1,4 +1,4 @@
-pub mod single;
 pub mod dns;
+pub mod factory;
+pub mod single;
 pub mod traits;
-pub mod factory;
\ No newline at end of file
diff --git a/src/client/load_balancer/single.rs b/src/client/load_balancer/single.rs
index 7e53941..4a54863 100644
--- a/src/client/load_balancer/single.rs
+++ b/src/client/load_balancer/single.rs
@@ -14,39 +14,40 @@
  *
  */
 
-use tonic::transport::{Channel, ClientTlsConfig};
 use super::traits::DynamicLoadBalancer;
+use tonic::transport::{Channel, ClientTlsConfig};
 
 use crate::util::error::Result;
-use tonic::transport::Endpoint;
 use async_trait::async_trait;
-
+use tonic::transport::Endpoint;
 
 pub struct SingleLoadBalancer {
     server: String,
     port: String,
-    client_config: Option<ClientTlsConfig>
+    client_config: Option<ClientTlsConfig>,
 }
 
 impl SingleLoadBalancer {
-    pub fn new(server: String, port: String, client_config: Option<ClientTlsConfig>) -> Result<Self> {
+    pub fn new(
+        server: String,
+        port: String,
+        client_config: Option<ClientTlsConfig>,
+    ) -> Result<Self> {
         Ok(Self {
             server,
             port,
-            client_config
+            client_config,
         })
     }
-
 }
 
 #[async_trait]
 impl DynamicLoadBalancer for SingleLoadBalancer {
     fn get_transport_channel(&self) -> Result<Channel> {
-        let mut endpoint = Endpoint::from_shared(
-            format!("http://{}:{}", self.server, self.port))?;
+        let mut endpoint = Endpoint::from_shared(format!("http://{}:{}", self.server, self.port))?;
         if let Some(tls_config) = self.client_config.clone() {
             endpoint = endpoint.tls_config(tls_config)?
} Ok(Channel::balance_list(vec![endpoint].into_iter())) } -} \ No newline at end of file +} diff --git a/src/client/load_balancer/traits.rs b/src/client/load_balancer/traits.rs index 338dd0f..8bad17f 100644 --- a/src/client/load_balancer/traits.rs +++ b/src/client/load_balancer/traits.rs @@ -14,11 +14,11 @@ * */ -use tonic::transport::Channel; use crate::util::error::Result; use async_trait::async_trait; +use tonic::transport::Channel; #[async_trait] pub trait DynamicLoadBalancer { fn get_transport_channel(&self) -> Result; -} \ No newline at end of file +} diff --git a/src/client/mod.rs b/src/client/mod.rs index f19155d..88bc305 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,5 +1,5 @@ pub mod cmd; +pub mod file_handler; +pub mod load_balancer; pub mod sign_identity; pub mod worker; -pub mod file_handler; -pub mod load_balancer; \ No newline at end of file diff --git a/src/client/sign_identity.rs b/src/client/sign_identity.rs index 6002ae4..f01ba9c 100644 --- a/src/client/sign_identity.rs +++ b/src/client/sign_identity.rs @@ -16,12 +16,11 @@ use std::path::PathBuf; -use std::cell::{RefCell}; -use crate::util::error::{Result}; +use crate::util::error::Result; use crate::util::sign::{FileType, KeyType}; +use std::cell::RefCell; use std::collections::HashMap; - pub struct SignIdentity { //absolute file path pub file_path: PathBuf, @@ -35,7 +34,13 @@ pub struct SignIdentity { } impl SignIdentity { - pub(crate) fn new(file_type: FileType, file_path: PathBuf, key_type: KeyType, key_id: String, sign_options: HashMap) -> Self { + pub(crate) fn new( + file_type: FileType, + file_path: PathBuf, + key_type: KeyType, + key_id: String, + sign_options: HashMap, + ) -> Self { Self { file_type, file_path, diff --git a/src/client/worker/assembler.rs b/src/client/worker/assembler.rs index d2c39d3..b0ed45f 100644 --- a/src/client/worker/assembler.rs +++ b/src/client/worker/assembler.rs @@ -14,35 +14,30 @@ * */ +use crate::client::sign_identity::SignIdentity; use std::collections::HashMap; -use crate::client::{sign_identity::SignIdentity}; - -use crate::client::worker::traits::SignHandler; use crate::client::file_handler::traits::FileHandler; +use crate::client::worker::traits::SignHandler; +use crate::util::error::Error; use async_trait::async_trait; -use std::path::{Path, PathBuf}; use std::fs::copy; -use crate::util::error::Error; - +use std::path::{Path, PathBuf}; use std::fs; pub struct Assembler { temp_dir: PathBuf, - key_attributes: HashMap + key_attributes: HashMap, } - impl Assembler { - - pub fn new(temp_dir: String, key_attributes: HashMap) -> Self { + pub fn new(temp_dir: String, key_attributes: HashMap) -> Self { Self { temp_dir: PathBuf::from(temp_dir), - key_attributes + key_attributes, } } - } #[async_trait] @@ -51,16 +46,32 @@ impl SignHandler for Assembler { async fn process(&mut self, handler: Box, item: SignIdentity) -> SignIdentity { let signatures: Vec> = (*item.signature).borrow().clone(); let sign_options = item.sign_options.borrow().clone(); - match handler.assemble_data(&item.file_path, signatures, &self.temp_dir, &sign_options, &self.key_attributes).await { + match handler + .assemble_data( + &item.file_path, + signatures, + &self.temp_dir, + &sign_options, + &self.key_attributes, + ) + .await + { Ok(content) => { - debug!("successfully assemble file {}", item.file_path.as_path().display()); + debug!( + "successfully assemble file {}", + item.file_path.as_path().display() + ); let temp_file = Path::new(&content.0); match copy(temp_file, Path::new(&content.1)) { Ok(_) 
=> { - debug!("successfully saved file {}", item.file_path.as_path().display()); + debug!( + "successfully saved file {}", + item.file_path.as_path().display() + ); } Err(err) => { - *item.error.borrow_mut() = Err(Error::AssembleFileError(format!("{:?}", err))); + *item.error.borrow_mut() = + Err(Error::AssembleFileError(format!("{:?}", err))); } } //remove temp file when finished @@ -72,4 +83,4 @@ impl SignHandler for Assembler { } item } -} \ No newline at end of file +} diff --git a/src/client/worker/key_fetcher.rs b/src/client/worker/key_fetcher.rs index 31e9376..ae7c1ae 100644 --- a/src/client/worker/key_fetcher.rs +++ b/src/client/worker/key_fetcher.rs @@ -14,17 +14,15 @@ * */ +use crate::util::error::{Error, Result}; use std::collections::HashMap; -use crate::util::error::{Result, Error}; pub mod signatrust { tonic::include_proto!("signatrust"); } +use self::signatrust::{signatrust_client::SignatrustClient, GetKeyInfoRequest}; use tonic::transport::Channel; -use self::signatrust::{ - signatrust_client::SignatrustClient, GetKeyInfoRequest -}; pub struct KeyFetcher { client: SignatrustClient, @@ -32,16 +30,19 @@ pub struct KeyFetcher { } impl KeyFetcher { - pub fn new(channel: Channel, token: Option) -> Self { Self { client: SignatrustClient::new(channel), - token + token, } } - pub async fn get_key_attributes(&mut self, key_name: &str, key_type: &str) -> Result> { - let key = GetKeyInfoRequest{ + pub async fn get_key_attributes( + &mut self, + key_name: &str, + key_type: &str, + ) -> Result> { + let key = GetKeyInfoRequest { key_type: key_type.to_string(), key_id: key_name.to_string(), token: self.token.clone(), @@ -55,9 +56,7 @@ impl KeyFetcher { Err(Error::RemoteSignError(format!("{:?}", data.error))) } } - Err(err) => { - Err(Error::RemoteSignError(format!("{:?}", err))) - } + Err(err) => Err(Error::RemoteSignError(format!("{:?}", err))), } } } diff --git a/src/client/worker/mod.rs b/src/client/worker/mod.rs index 437e916..e51b3ab 100644 --- a/src/client/worker/mod.rs +++ b/src/client/worker/mod.rs @@ -1,5 +1,5 @@ pub mod assembler; -pub mod splitter; +pub mod key_fetcher; pub mod signer; +pub mod splitter; pub mod traits; -pub mod key_fetcher; diff --git a/src/client/worker/signer.rs b/src/client/worker/signer.rs index 9635676..65b32a6 100644 --- a/src/client/worker/signer.rs +++ b/src/client/worker/signer.rs @@ -14,19 +14,17 @@ * */ -use crate::client::{sign_identity::SignIdentity}; -use crate::client::worker::traits::SignHandler; use crate::client::file_handler::traits::FileHandler; +use crate::client::sign_identity::SignIdentity; +use crate::client::worker::traits::SignHandler; use async_trait::async_trait; pub mod signatrust { tonic::include_proto!("signatrust"); } +use self::signatrust::{signatrust_client::SignatrustClient, SignStreamRequest}; use tonic::transport::Channel; -use self::signatrust::{ - signatrust_client::SignatrustClient, SignStreamRequest, -}; use crate::util::error::Error; use std::io::{Cursor, Read}; @@ -37,21 +35,23 @@ pub struct RemoteSigner { token: Option, } - impl RemoteSigner { - - pub fn new(channel: Channel, buffer_size: usize,token: Option) -> Self { + pub fn new(channel: Channel, buffer_size: usize, token: Option) -> Self { Self { client: SignatrustClient::new(channel), buffer_size, - token + token, } } } #[async_trait] impl SignHandler for RemoteSigner { - async fn process(&mut self, _handler: Box, item: SignIdentity) -> SignIdentity { + async fn process( + &mut self, + _handler: Box, + item: SignIdentity, + ) -> SignIdentity { let mut 
signed_content = Vec::new(); let read_data = item.raw_content.borrow().clone(); for sign_content in read_data.into_iter() { @@ -60,10 +60,10 @@ impl SignHandler for RemoteSigner { let mut cursor = Cursor::new(sign_content); while let Ok(length) = cursor.read(&mut buffer) { if length == 0 { - break + break; } let content = buffer[0..length].to_vec(); - sign_segments.push(SignStreamRequest{ + sign_segments.push(SignStreamRequest { data: content, options: item.sign_options.borrow().clone(), key_type: format!("{}", item.key_type), @@ -73,10 +73,12 @@ impl SignHandler for RemoteSigner { } if sign_segments.is_empty() { *item.error.borrow_mut() = Err(Error::FileContentEmpty); - return item + return item; } - let result = self.client.sign_stream( - tokio_stream::iter(sign_segments)).await; + let result = self + .client + .sign_stream(tokio_stream::iter(sign_segments)) + .await; match result { Ok(result) => { let data = result.into_inner(); @@ -91,11 +93,13 @@ impl SignHandler for RemoteSigner { } } } - debug!("successfully sign file {}", item.file_path.as_path().display()); + debug!( + "successfully sign file {}", + item.file_path.as_path().display() + ); *item.signature.borrow_mut() = signed_content; //clear out temporary value *item.raw_content.borrow_mut() = Vec::new(); item } } - diff --git a/src/client/worker/splitter.rs b/src/client/worker/splitter.rs index 0f74f92..190a014 100644 --- a/src/client/worker/splitter.rs +++ b/src/client/worker/splitter.rs @@ -14,26 +14,21 @@ * */ -use crate::client::{sign_identity::SignIdentity}; +use crate::client::sign_identity::SignIdentity; - -use crate::client::worker::traits::SignHandler; use crate::client::file_handler::traits::FileHandler; +use crate::client::worker::traits::SignHandler; +use crate::util::error; use async_trait::async_trait; use std::collections::HashMap; -use crate::util::error; pub struct Splitter { - key_attributes: HashMap + key_attributes: HashMap, } - impl Splitter { - pub fn new(key_attributes: HashMap) -> Self { - Self { - key_attributes - } + Self { key_attributes } } } @@ -41,11 +36,17 @@ impl Splitter { impl SignHandler for Splitter { async fn process(&mut self, handler: Box, item: SignIdentity) -> SignIdentity { let mut sign_options = item.sign_options.borrow().clone(); - match handler.split_data(&item.file_path, &mut sign_options, &self.key_attributes).await { + match handler + .split_data(&item.file_path, &mut sign_options, &self.key_attributes) + .await + { Ok(content) => { *item.raw_content.borrow_mut() = content; *item.sign_options.borrow_mut() = sign_options; - debug!("successfully split file {}", item.file_path.as_path().display()); + debug!( + "successfully split file {}", + item.file_path.as_path().display() + ); } Err(err) => { *item.error.borrow_mut() = Err(error::Error::SplitFileError(format!("{:?}", err))) @@ -54,4 +55,3 @@ impl SignHandler for Splitter { item } } - diff --git a/src/client/worker/traits.rs b/src/client/worker/traits.rs index 64f97d7..8153fd2 100644 --- a/src/client/worker/traits.rs +++ b/src/client/worker/traits.rs @@ -14,11 +14,11 @@ * */ -use async_trait::async_trait; -use crate::client::sign_identity::{SignIdentity}; -use async_channel::{Sender}; use crate::client::file_handler::factory::FileHandlerFactory; use crate::client::file_handler::traits::FileHandler; +use crate::client::sign_identity::SignIdentity; +use async_channel::Sender; +use async_trait::async_trait; #[async_trait] pub trait SignHandler { @@ -37,4 +37,4 @@ pub trait SignHandler { } //NOTE: instead of raise out error for 
specific sign object out of method, we need record error inside of the SignIdentity object. async fn process(&mut self, handler: Box, item: SignIdentity) -> SignIdentity; -} \ No newline at end of file +} diff --git a/src/client_entrypoint.rs b/src/client_entrypoint.rs index 0ce7f47..34f315c 100644 --- a/src/client_entrypoint.rs +++ b/src/client_entrypoint.rs @@ -15,18 +15,18 @@ */ #![allow(dead_code)] -use std::env; -use crate::util::error::{Result, Error}; -use clap::{Parser, Subcommand}; use crate::client::cmd::add; -use config::{Config, File}; -use std::sync::{Arc, atomic::AtomicBool, RwLock}; use crate::client::cmd::traits::SignCommand; +use crate::util::error::{Error, Result}; +use clap::{Parser, Subcommand}; +use config::{Config, File}; +use std::env; +use std::sync::{atomic::AtomicBool, Arc, RwLock}; -mod infra; -mod util; mod client; mod domain; +mod infra; +mod util; #[macro_use] extern crate log; @@ -41,7 +41,7 @@ extern crate lazy_static; pub struct App { #[arg(short, long)] #[arg( - help = "path of configuration file, './client.toml' relative to working directory be used in default" + help = "path of configuration file, './client.toml' relative to working directory be used in default" )] config: Option, #[command(subcommand)] @@ -58,19 +58,29 @@ fn main() -> Result<()> { //prepare config and logger env_logger::init(); let app = App::parse(); - let path = app.config.unwrap_or( - format!("{}/{}", env::current_dir().expect("current dir not found").display(), "client.toml")); - let client = Config::builder().add_source(File::with_name(path.as_str())).build().expect("load client configuration file"); + let path = app.config.unwrap_or(format!( + "{}/{}", + env::current_dir().expect("current dir not found").display(), + "client.toml" + )); + let client = Config::builder() + .add_source(File::with_name(path.as_str())) + .build() + .expect("load client configuration file"); let signal = Arc::new(AtomicBool::new(false)); - signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(&signal)).expect("failed to register sigterm signal"); - signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(&signal)).expect("failed to register sigint signal"); + signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(&signal)) + .expect("failed to register sigterm signal"); + signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(&signal)) + .expect("failed to register sigint signal"); //construct handler let command = match app.command { - Some(Commands::Add(add_command)) => { - Some(add::CommandAddHandler::new(signal, Arc::new(RwLock::new(client)), add_command)?) 
- } - None => {None} + Some(Commands::Add(add_command)) => Some(add::CommandAddHandler::new( + signal, + Arc::new(RwLock::new(client)), + add_command, + )?), + None => None, }; //handler and quit if let Some(handler) = command { @@ -81,7 +91,7 @@ fn main() -> Result<()> { if let Err(err) = handler.handle() { error!("failed to handle command: {}", err); - return Err(Error::PartialSuccessError) + return Err(Error::PartialSuccessError); } } Ok(()) diff --git a/src/control_admin_entrypoint.rs b/src/control_admin_entrypoint.rs index 43264dc..e1e12a1 100644 --- a/src/control_admin_entrypoint.rs +++ b/src/control_admin_entrypoint.rs @@ -16,28 +16,26 @@ #![allow(dead_code)] -use std::collections::HashMap; -use std::env; -use validator::Validate; -use chrono::{Duration, Utc}; -use crate::util::error::{Result}; +use crate::domain::datakey::entity::DataKey; use crate::domain::datakey::entity::{KeyType as EntityKeyTpe, KeyType}; -use clap::{Parser, Subcommand}; -use clap::{Args}; -use tokio_util::sync::CancellationToken; -use crate::domain::datakey::entity::{DataKey}; use crate::domain::user::entity::User; -use crate::presentation::handler::control::model::datakey::dto::{CreateDataKeyDTO}; +use crate::presentation::handler::control::model::datakey::dto::CreateDataKeyDTO; use crate::presentation::handler::control::model::user::dto::UserIdentity; +use crate::util::error::Result; +use chrono::{Duration, Utc}; +use clap::Args; +use clap::{Parser, Subcommand}; +use std::collections::HashMap; +use std::env; use std::str::FromStr; +use tokio_util::sync::CancellationToken; +use validator::Validate; - -mod util; -mod infra; -mod domain; mod application; +mod domain; +mod infra; mod presentation; - +mod util; #[macro_use] extern crate log; @@ -50,7 +48,7 @@ extern crate log; pub struct App { #[arg(short, long)] #[arg( - help = "path of configuration file, './client.toml' relative to working directory be used in default" + help = "path of configuration file, './client.toml' relative to working directory be used in default" )] config: Option, #[command(subcommand)] @@ -133,20 +131,50 @@ fn generate_keys_parameters(command: &CommandGenerateKeys) -> HashMap Result<()> { //prepare config and logger env_logger::init(); let app = App::parse(); - let path = app.config.unwrap_or(format!("{}/{}", env::current_dir().expect("current dir not found").display(), - "config/server.toml")); + let path = app.config.unwrap_or(format!( + "{}/{}", + env::current_dir().expect("current dir not found").display(), + "config/server.toml" + )); let server_config = util::config::ServerConfig::new(path); //cancel token will never been used/canceled here cause it's only used for background threads in control server instance. - let control_server = presentation::server::control_server::ControlServer::new(server_config.config, CancellationToken::new()).await?; + let control_server = presentation::server::control_server::ControlServer::new( + server_config.config, + CancellationToken::new(), + ) + .await?; //handle commands match app.command { Some(Commands::CreateAdmin(create_admin)) => { - let token = control_server.create_user_token(User::new(create_admin.email.clone())?).await?; + let token = control_server + .create_user_token(User::new(create_admin.email.clone())?) 
+ .await?; info!("[Result]: Administrator {} has been successfully created with token {} will expire {}", &create_admin.email, &token.token, &token.expire_at) } Some(Commands::GenerateKeys(generate_keys)) => { - let user = control_server.get_user_by_email(&generate_keys.email).await?; + let user = control_server + .get_user_by_email(&generate_keys.email) + .await?; let now = Utc::now(); let mut key = CreateDataKeyDTO { @@ -192,8 +231,16 @@ async fn main() -> Result<()> { } key.validate()?; - let keys = control_server.create_keys(&mut DataKey::create_from(key, UserIdentity::from_user(user.clone()))?, UserIdentity::from_user(user)).await?; - info!("[Result]: Keys {} type {} has been successfully generated", &keys.name, &generate_keys.key_type) + let keys = control_server + .create_keys( + &mut DataKey::create_from(key, UserIdentity::from_user(user.clone()))?, + UserIdentity::from_user(user), + ) + .await?; + info!( + "[Result]: Keys {} type {} has been successfully generated", + &keys.name, &generate_keys.key_type + ) } None => {} }; diff --git a/src/control_server_entrypoint.rs b/src/control_server_entrypoint.rs index dc2f89b..cc1d213 100644 --- a/src/control_server_entrypoint.rs +++ b/src/control_server_entrypoint.rs @@ -20,17 +20,17 @@ use clap::Parser; use config::Config; use std::env; use std::sync::{Arc, RwLock}; -use tokio_util::sync::CancellationToken; use tokio::{ select, signal::unix::{signal, SignalKind}, }; +use tokio_util::sync::CancellationToken; -mod infra; +mod application; mod domain; +mod infra; mod presentation; mod util; -mod application; #[macro_use] extern crate log; @@ -45,7 +45,7 @@ extern crate lazy_static; pub struct App { #[arg(short, long)] #[arg( - help = "path of configuration file, 'config/server.toml' relative to working directory be used in default" + help = "path of configuration file, 'config/server.toml' relative to working directory be used in default" )] config: Option, } @@ -85,7 +85,11 @@ async fn main() -> Result<()> { //prepare config and logger env_logger::init(); //control server starts - let control_server = presentation::server::control_server::ControlServer::new(SERVERCONFIG.clone(), CANCEL_TOKEN.clone()).await?; + let control_server = presentation::server::control_server::ControlServer::new( + SERVERCONFIG.clone(), + CANCEL_TOKEN.clone(), + ) + .await?; control_server.run().await?; Ok(()) } diff --git a/src/data_server_entrypoint.rs b/src/data_server_entrypoint.rs index a00e8a2..d2f656e 100644 --- a/src/data_server_entrypoint.rs +++ b/src/data_server_entrypoint.rs @@ -20,19 +20,19 @@ use clap::Parser; use config::Config; use std::env; use std::sync::{Arc, RwLock}; -use tokio_util::sync::CancellationToken; use tokio::{ select, signal::unix::{signal, SignalKind}, }; +use tokio_util::sync::CancellationToken; use crate::presentation::server::data_server::DataServer; -mod infra; +mod application; mod domain; +mod infra; mod presentation; mod util; -mod application; #[macro_use] extern crate log; @@ -87,7 +87,8 @@ async fn main() -> Result<()> { //prepare config and logger env_logger::init(); //data server starts - let data_server: DataServer = DataServer::new(SERVERCONFIG.clone(), CANCEL_TOKEN.clone()).await?; + let data_server: DataServer = + DataServer::new(SERVERCONFIG.clone(), CANCEL_TOKEN.clone()).await?; data_server.run().await?; Ok(()) } diff --git a/src/domain/clusterkey/entity.rs b/src/domain/clusterkey/entity.rs index 564735b..2fdbd28 100644 --- a/src/domain/clusterkey/entity.rs +++ b/src/domain/clusterkey/entity.rs @@ -15,8 +15,8 @@ */ 
use crate::util::{error::Result, key}; -use secstr::SecVec; use chrono::{DateTime, Utc}; +use secstr::SecVec; use std::fmt::{Display, Formatter}; use std::vec::Vec; @@ -77,7 +77,6 @@ pub struct SecClusterKey { } impl Default for SecClusterKey { - fn default() -> Self { SecClusterKey { id: 0, @@ -91,7 +90,9 @@ impl Default for SecClusterKey { impl SecClusterKey { pub async fn load(cluster_key: ClusterKey, kms_provider: &Box) -> Result - where K: KMSProvider + ?Sized { + where + K: KMSProvider + ?Sized, + { Ok(Self { id: cluster_key.id, data: SecVec::new(key::decode_hex_string_to_u8( @@ -119,8 +120,8 @@ impl Display for SecClusterKey { #[cfg(test)] mod test { use super::*; - use std::collections::HashMap; use crate::infra::kms::dummy::DummyKMS; + use std::collections::HashMap; fn get_dummy_kms_provider() -> Box { Box::new(DummyKMS::new(&HashMap::new()).unwrap()) @@ -129,12 +130,18 @@ mod test { #[tokio::test] async fn test_sec_cluster_key_load_and_display() { let kms_provider = get_dummy_kms_provider(); - let content = vec![1,2,3,4]; + let content = vec![1, 2, 3, 4]; let hexed_content = key::encode_u8_to_hex_string(&content).as_bytes().to_vec(); - let cluster_key = ClusterKey::new(hexed_content, "FAKE_ALGORITHM".to_string()).expect("create cluster key failed"); - let sec_cluster_key = SecClusterKey::load(cluster_key, &kms_provider).await.expect("load cluster key failed"); + let cluster_key = ClusterKey::new(hexed_content, "FAKE_ALGORITHM".to_string()) + .expect("create cluster key failed"); + let sec_cluster_key = SecClusterKey::load(cluster_key, &kms_provider) + .await + .expect("load cluster key failed"); assert_eq!(sec_cluster_key.data.unsecure(), content); - assert_eq!(true, format!("{}", sec_cluster_key).contains("FAKE_ALGORITHM")); + assert_eq!( + true, + format!("{}", sec_cluster_key).contains("FAKE_ALGORITHM") + ); } #[test] @@ -157,11 +164,11 @@ mod test { #[tokio::test] async fn test_cluster_key_new_and_display() { - let content = vec![1,2,3,4]; + let content = vec![1, 2, 3, 4]; let hexed_content = key::encode_u8_to_hex_string(&content).as_bytes().to_vec(); - let cluster_key = ClusterKey::new(hexed_content.clone(), "FAKE_ALGORITHM".to_string()).expect("create cluster key failed"); + let cluster_key = ClusterKey::new(hexed_content.clone(), "FAKE_ALGORITHM".to_string()) + .expect("create cluster key failed"); assert_eq!(cluster_key.data, hexed_content); assert_eq!(true, format!("{}", cluster_key).contains("FAKE_ALGORITHM")); } } - diff --git a/src/domain/datakey/entity.rs b/src/domain/datakey/entity.rs index fcbcbd8..7d46115 100644 --- a/src/domain/datakey/entity.rs +++ b/src/domain/datakey/entity.rs @@ -28,8 +28,6 @@ use crate::domain::encryption_engine::EncryptionEngine; pub const INFRA_CONFIG_DOMAIN_NAME: &str = "domain_name"; - - #[derive(Debug, Clone, Default, PartialEq)] pub enum KeyState { Enabled, @@ -38,7 +36,7 @@ pub enum KeyState { PendingRevoke, Revoked, PendingDelete, - Deleted + Deleted, } impl FromStr for KeyState { @@ -52,7 +50,10 @@ impl FromStr for KeyState { "revoked" => Ok(KeyState::Revoked), "pending_delete" => Ok(KeyState::PendingDelete), "deleted" => Ok(KeyState::Deleted), - _ => Err(Error::UnsupportedTypeError(format!("unsupported data key state {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported data key state {}", + s + ))), } } } @@ -97,7 +98,10 @@ impl FromStr for KeyAction { "issue_cert" => Ok(KeyAction::IssueCert), "sign" => Ok(KeyAction::Sign), "read" => Ok(KeyAction::Read), - _ => 
Err(Error::UnsupportedTypeError(format!("unsupported data key action {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported data key action {}", + s + ))), } } } @@ -138,7 +142,10 @@ impl FromStr for KeyType { "x509ca" => Ok(KeyType::X509CA), "x509ica" => Ok(KeyType::X509ICA), "x509ee" => Ok(KeyType::X509EE), - _ => Err(Error::UnsupportedTypeError(format!("unsupported data key type {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported data key type {}", + s + ))), } } } @@ -164,7 +171,12 @@ pub struct X509CRL { } impl X509CRL { - pub fn new(ca_id: i32, data: Vec, create_at: DateTime, update_at: DateTime) -> Self { + pub fn new( + ca_id: i32, + data: Vec, + create_at: DateTime, + update_at: DateTime, + ) -> Self { X509CRL { id: 0, ca_id, @@ -202,7 +214,10 @@ impl FromStr for X509RevokeReason { "certificate_hold" => Ok(X509RevokeReason::CertificateHold), "privilege_withdrawn" => Ok(X509RevokeReason::PrivilegeWithdrawn), "aa_compromise" => Ok(X509RevokeReason::AACompromise), - _ => Err(Error::UnsupportedTypeError(format!("unsupported x509 revoke reason {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported x509 revoke reason {}", + s + ))), } } } @@ -239,7 +254,7 @@ pub struct RevokedKey { pub ca_id: i32, pub reason: X509RevokeReason, pub create_at: DateTime, - pub serial_number: Option + pub serial_number: Option, } #[derive(Debug, Clone, PartialEq)] @@ -263,7 +278,7 @@ pub struct DataKey { pub user_email: Option, pub request_delete_users: Option, pub request_revoke_users: Option, - pub parent_key: Option + pub parent_key: Option, } #[derive(Debug, Clone)] pub struct PagedMeta { @@ -272,7 +287,7 @@ pub struct PagedMeta { #[derive(Debug, Clone)] pub struct PagedDatakey { pub data: Vec, - pub meta: PagedMeta + pub meta: PagedMeta, } #[derive(Debug, Clone)] @@ -301,11 +316,7 @@ impl Identity for DataKey { fn get_identity(&self) -> String { format!( "", - self.id, - self.name, - self.user, - self.key_type, - self.fingerprint + self.id, self.name, self.user, self.key_type, self.fingerprint ) } } @@ -316,7 +327,7 @@ pub struct SecParentDateKey { pub private_key: SecVec, pub public_key: SecVec, pub certificate: SecVec, - pub attributes: HashMap + pub attributes: HashMap, } pub struct SecDataKey { @@ -326,11 +337,14 @@ pub struct SecDataKey { pub certificate: SecVec, pub identity: String, pub attributes: HashMap, - pub parent: Option + pub parent: Option, } impl SecDataKey { - pub async fn load(data_key: &DataKey, engine: &Box) -> Result { + pub async fn load( + data_key: &DataKey, + engine: &Box, + ) -> Result { let mut sec_datakey = Self { name: data_key.name.clone(), private_key: SecVec::new(engine.decode(data_key.private_key.clone()).await?), @@ -366,22 +380,23 @@ pub struct DataKeyContent { pub enum Visibility { #[default] Public, - Private + Private, } impl Visibility { pub fn from_parameter(s: Option) -> Result { match s { - None => { - Ok(Visibility::Public) - } + None => Ok(Visibility::Public), Some(value) => { if value == "public" { return Ok(Visibility::Public); } else if value == "private" { return Ok(Visibility::Private); } - Err(Error::UnsupportedTypeError(format!("unsupported data key visibility {}", value))) + Err(Error::UnsupportedTypeError(format!( + "unsupported data key visibility {}", + value + ))) } } } @@ -394,7 +409,10 @@ impl FromStr for Visibility { match s { "public" => Ok(Visibility::Public), "private" => Ok(Visibility::Private), - _ => Err(Error::UnsupportedTypeError(format!("unsupported data key visibility {}", 
s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported data key visibility {}", + s + ))), } } } diff --git a/src/domain/datakey/mod.rs b/src/domain/datakey/mod.rs index 51eb69a..81d143c 100644 --- a/src/domain/datakey/mod.rs +++ b/src/domain/datakey/mod.rs @@ -1,4 +1,4 @@ pub mod entity; +pub mod plugins; pub mod repository; pub mod traits; -pub mod plugins; diff --git a/src/domain/datakey/plugins/mod.rs b/src/domain/datakey/plugins/mod.rs index 3cdf0b0..3e23572 100644 --- a/src/domain/datakey/plugins/mod.rs +++ b/src/domain/datakey/plugins/mod.rs @@ -13,4 +13,3 @@ */ pub mod openpgp; pub mod x509; - diff --git a/src/domain/datakey/plugins/openpgp.rs b/src/domain/datakey/plugins/openpgp.rs index c5cc5f4..07abe93 100644 --- a/src/domain/datakey/plugins/openpgp.rs +++ b/src/domain/datakey/plugins/openpgp.rs @@ -12,14 +12,14 @@ * // See the Mulan PSL v2 for more details. */ -use std::str::FromStr; use crate::util::error::{Error, Result}; -use std::fmt::{Display, Formatter}; -use std::fmt; -use pgp::composed::{KeyType}; -use pgp::crypto::{hash::HashAlgorithm}; -use enum_iterator::{Sequence}; +use enum_iterator::Sequence; +use pgp::composed::KeyType; +use pgp::crypto::hash::HashAlgorithm; use serde::Deserialize; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; pub const PGP_VALID_KEY_SIZE: [&str; 3] = ["2048", "3072", "4096"]; @@ -38,7 +38,10 @@ impl FromStr for OpenPGPKeyType { match s { "rsa" => Ok(OpenPGPKeyType::Rsa), "eddsa" => Ok(OpenPGPKeyType::Eddsa), - _ => Err(Error::UnsupportedTypeError(format!("unsupported openpgp key state {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported openpgp key state {}", + s + ))), } } } @@ -56,14 +59,14 @@ impl OpenPGPKeyType { //key length defaults to 2048 pub fn get_real_key_type(&self, key_length: Option) -> KeyType { match self { - OpenPGPKeyType::Rsa => { - if let Some(length) = key_length { - KeyType::Rsa(length.parse().unwrap()) - } else { - KeyType::Rsa(2048) - } - }, - OpenPGPKeyType::Eddsa => KeyType::EdDSA + OpenPGPKeyType::Rsa => { + if let Some(length) = key_length { + KeyType::Rsa(length.parse().unwrap()) + } else { + KeyType::Rsa(2048) + } + } + OpenPGPKeyType::Eddsa => KeyType::EdDSA, } } } @@ -120,7 +123,10 @@ impl FromStr for OpenPGPDigestAlgorithm { "sha2_512" => Ok(OpenPGPDigestAlgorithm::SHA2_512), "sha3_256" => Ok(OpenPGPDigestAlgorithm::SHA3_256), "sha3_512" => Ok(OpenPGPDigestAlgorithm::SHA3_512), - _ => Err(Error::UnsupportedTypeError(format!("unsupported openpgp digest algorithm {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported openpgp digest algorithm {}", + s + ))), } } } @@ -130,7 +136,7 @@ impl OpenPGPDigestAlgorithm { match self { OpenPGPDigestAlgorithm::None => HashAlgorithm::None, OpenPGPDigestAlgorithm::MD5 => HashAlgorithm::MD5, - OpenPGPDigestAlgorithm::SHA1=> HashAlgorithm::SHA1, + OpenPGPDigestAlgorithm::SHA1 => HashAlgorithm::SHA1, OpenPGPDigestAlgorithm::SHA2_224 => HashAlgorithm::SHA2_224, OpenPGPDigestAlgorithm::SHA2_256 => HashAlgorithm::SHA2_256, OpenPGPDigestAlgorithm::SHA2_384 => HashAlgorithm::SHA2_384, diff --git a/src/domain/datakey/plugins/x509.rs b/src/domain/datakey/plugins/x509.rs index 4b2fe5c..da4c8e1 100644 --- a/src/domain/datakey/plugins/x509.rs +++ b/src/domain/datakey/plugins/x509.rs @@ -11,16 +11,16 @@ * // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. * // See the Mulan PSL v2 for more details. 
*/ -use std::str::FromStr; use crate::util::error::{Error, Result}; -use std::fmt::{Display, Formatter}; -use std::fmt; -use openssl::pkey::{PKey, Private}; -use openssl::rsa::Rsa; +use enum_iterator::Sequence; use openssl::dsa::Dsa; use openssl::hash::MessageDigest; -use enum_iterator::{Sequence}; +use openssl::pkey::{PKey, Private}; +use openssl::rsa::Rsa; use serde::Deserialize; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; pub const X509_VALID_KEY_SIZE: [&str; 3] = ["2048", "3072", "4096"]; @@ -39,7 +39,10 @@ impl FromStr for X509KeyType { match s { "rsa" => Ok(X509KeyType::Rsa), "dsa" => Ok(X509KeyType::Dsa), - _ => Err(Error::UnsupportedTypeError(format!("unsupported x509 key type {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported x509 key type {}", + s + ))), } } } @@ -102,7 +105,10 @@ impl FromStr for X509DigestAlgorithm { "sha2_256" => Ok(X509DigestAlgorithm::SHA2_256), "sha2_384" => Ok(X509DigestAlgorithm::SHA2_384), "sha2_512" => Ok(X509DigestAlgorithm::SHA2_512), - _ => Err(Error::UnsupportedTypeError(format!("unsupported x509 digest algorithm {}", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "unsupported x509 digest algorithm {}", + s + ))), } } } @@ -115,7 +121,7 @@ impl X509DigestAlgorithm { X509DigestAlgorithm::SHA2_224 => MessageDigest::sha224(), X509DigestAlgorithm::SHA2_256 => MessageDigest::sha256(), X509DigestAlgorithm::SHA2_384 => MessageDigest::sha384(), - X509DigestAlgorithm::SHA2_512 => MessageDigest::sha512() + X509DigestAlgorithm::SHA2_512 => MessageDigest::sha512(), } } } diff --git a/src/domain/datakey/repository.rs b/src/domain/datakey/repository.rs index cabbc21..a2694b0 100644 --- a/src/domain/datakey/repository.rs +++ b/src/domain/datakey/repository.rs @@ -15,22 +15,51 @@ */ use super::entity::DataKey; +use crate::domain::datakey::entity::{ + DatakeyPaginationQuery, KeyState, PagedDatakey, RevokedKey, X509RevokeReason, X509CRL, +}; use crate::util::error::Result; use async_trait::async_trait; use chrono::Duration; -use crate::domain::datakey::entity::{DatakeyPaginationQuery, KeyState, PagedDatakey, RevokedKey, X509CRL, X509RevokeReason}; #[async_trait] pub trait Repository: Send + Sync { async fn create(&self, data_key: DataKey) -> Result; async fn delete(&self, id: i32) -> Result<()>; - async fn get_all_keys(&self, user_id: i32, query: DatakeyPaginationQuery) -> Result; - async fn get_by_id_or_name(&self, id: Option, name: Option, raw_datakey: bool) -> Result; + async fn get_all_keys( + &self, + user_id: i32, + query: DatakeyPaginationQuery, + ) -> Result; + async fn get_by_id_or_name( + &self, + id: Option, + name: Option, + raw_datakey: bool, + ) -> Result; async fn update_state(&self, id: i32, state: KeyState) -> Result<()>; async fn update_key_data(&self, data_key: DataKey) -> Result<()>; - async fn get_enabled_key_by_type_and_name_with_parent_key(&self, key_type: String, name: String) -> Result; - async fn request_delete_key(&self, user_id: i32, user_email: String, id: i32, public_key: bool) -> Result<()>; - async fn request_revoke_key(&self, user_id: i32, user_email: String, id: i32, parent_id: i32, reason: X509RevokeReason, public_key: bool) -> Result<()>; + async fn get_enabled_key_by_type_and_name_with_parent_key( + &self, + key_type: String, + name: String, + ) -> Result; + async fn request_delete_key( + &self, + user_id: i32, + user_email: String, + id: i32, + public_key: bool, + ) -> Result<()>; + async fn request_revoke_key( + &self, + user_id: i32, + user_email: String, + id: 
i32, + parent_id: i32, + reason: X509RevokeReason, + public_key: bool, + ) -> Result<()>; async fn cancel_delete_key(&self, user_id: i32, id: i32) -> Result<()>; async fn cancel_revoke_key(&self, user_id: i32, id: i32, parent_id: i32) -> Result<()>; //crl related methods diff --git a/src/domain/encryption_engine.rs b/src/domain/encryption_engine.rs index 1912629..985d5f9 100644 --- a/src/domain/encryption_engine.rs +++ b/src/domain/encryption_engine.rs @@ -14,8 +14,8 @@ * */ -use async_trait::async_trait; use crate::util::error::Result; +use async_trait::async_trait; #[async_trait] pub trait EncryptionEngine: Send + Sync { @@ -23,4 +23,4 @@ pub trait EncryptionEngine: Send + Sync { async fn rotate_key(&mut self) -> Result; async fn encode(&self, content: Vec) -> Result>; async fn decode(&self, content: Vec) -> Result>; -} \ No newline at end of file +} diff --git a/src/domain/encryptor.rs b/src/domain/encryptor.rs index 933ae1a..41cd7d4 100644 --- a/src/domain/encryptor.rs +++ b/src/domain/encryptor.rs @@ -57,9 +57,9 @@ mod test { #[test] fn test_algorithm_from_string_and_display() { - let _ = Algorithm::from_str("invalid_algorithm").expect_err("algorithm from invalid string should fail"); + let _ = Algorithm::from_str("invalid_algorithm") + .expect_err("algorithm from invalid string should fail"); let algorithm = Algorithm::from_str("aes256gsm").expect("algorithm from string failed"); assert_eq!(format!("{}", algorithm), "Aes256GSM"); } } - diff --git a/src/domain/kms_provider.rs b/src/domain/kms_provider.rs index e120bd8..9113437 100644 --- a/src/domain/kms_provider.rs +++ b/src/domain/kms_provider.rs @@ -36,7 +36,6 @@ impl FromStr for KMSType { } } - #[async_trait] pub trait KMSProvider: Send + Sync { async fn encode(&self, content: String) -> Result; @@ -49,10 +48,11 @@ mod test { #[test] fn test_kms_type_from_string_and_display() { - let _ = KMSType::from_str("invalid_type").expect_err("kms type from invalid string should fail"); + let _ = KMSType::from_str("invalid_type") + .expect_err("kms type from invalid string should fail"); let kms_type1 = KMSType::from_str("huaweicloud").expect("kms type from string failed"); assert_eq!(kms_type1, KMSType::HuaweiCloud); let kms_type2 = KMSType::from_str("dummy").expect("kms type from string failed"); assert_eq!(kms_type2, KMSType::Dummy); } -} \ No newline at end of file +} diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 4d381d3..ac48d89 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -1,9 +1,9 @@ pub mod clusterkey; pub mod datakey; -pub mod user; -pub mod token; pub mod encryption_engine; pub mod encryptor; pub mod kms_provider; pub mod sign_plugin; pub mod sign_service; +pub mod token; +pub mod user; diff --git a/src/domain/sign_plugin.rs b/src/domain/sign_plugin.rs index 6ad2557..dd2c901 100644 --- a/src/domain/sign_plugin.rs +++ b/src/domain/sign_plugin.rs @@ -14,26 +14,35 @@ * */ +use crate::domain::datakey::entity::{DataKey, DataKeyContent, KeyType, RevokedKey, SecDataKey}; use crate::util::error::Result; -use std::collections::HashMap; use chrono::{DateTime, Utc}; -use crate::domain::datakey::entity::{DataKey, DataKeyContent, KeyType, RevokedKey, SecDataKey}; +use std::collections::HashMap; pub trait SignPlugins: Send + Sync { fn new(db: SecDataKey) -> Result - where - Self: Sized; + where + Self: Sized; fn validate_and_update(key: &mut DataKey) -> Result<()> - where - Self: Sized; + where + Self: Sized; fn parse_attributes( private_key: Option>, public_key: Option>, certificate: Option>, ) -> HashMap - where - Self: 
Sized; - fn generate_keys(&self, key_type: &KeyType, infra_configs: &HashMap) -> Result; + where + Self: Sized; + fn generate_keys( + &self, + key_type: &KeyType, + infra_configs: &HashMap, + ) -> Result; fn sign(&self, content: Vec, options: HashMap) -> Result>; - fn generate_crl_content(&self, revoked_keys: Vec, last_update: DateTime, next_update: DateTime) -> Result>; + fn generate_crl_content( + &self, + revoked_keys: Vec, + last_update: DateTime, + next_update: DateTime, + ) -> Result>; } diff --git a/src/domain/sign_service.rs b/src/domain/sign_service.rs index b4e53ab..ab7ed34 100644 --- a/src/domain/sign_service.rs +++ b/src/domain/sign_service.rs @@ -14,7 +14,7 @@ * */ -use crate::util::error::{Result, Error}; +use crate::util::error::{Error, Result}; use std::collections::HashMap; use std::str::FromStr; @@ -33,30 +33,45 @@ impl FromStr for SignBackendType { fn from_str(s: &str) -> Result { match s { "memory" => Ok(SignBackendType::Memory), - _ => Err(Error::UnsupportedTypeError(format!("{} sign backend type", s))), + _ => Err(Error::UnsupportedTypeError(format!( + "{} sign backend type", + s + ))), } } } #[async_trait] -pub trait SignBackend: Send + Sync{ +pub trait SignBackend: Send + Sync { async fn validate_and_update(&self, data_key: &mut DataKey) -> Result<()>; async fn generate_keys(&self, data_key: &mut DataKey) -> Result<()>; async fn rotate_key(&mut self) -> Result; - async fn sign(&self, data_key: &DataKey, content: Vec, options: HashMap) -> Result>; + async fn sign( + &self, + data_key: &DataKey, + content: Vec, + options: HashMap, + ) -> Result>; async fn decode_public_keys(&self, data_key: &mut DataKey) -> Result<()>; - async fn generate_crl_content(&self, data_key: &DataKey, revoked_keys: Vec, last_update: DateTime, next_update: DateTime) -> Result>; + async fn generate_crl_content( + &self, + data_key: &DataKey, + revoked_keys: Vec, + last_update: DateTime, + next_update: DateTime, + ) -> Result>; } - #[cfg(test)] mod test { use super::*; #[test] fn test_sign_backend_type_from_string_and_display() { - let _ = SignBackendType::from_str("invalid_type").expect_err("sign backend type from invalid string should fail"); - let sign_backend_type = SignBackendType::from_str("memory").expect("sign backend type from string failed"); + let _ = SignBackendType::from_str("invalid_type") + .expect_err("sign backend type from invalid string should fail"); + let sign_backend_type = + SignBackendType::from_str("memory").expect("sign backend type from string failed"); assert_eq!(sign_backend_type, SignBackendType::Memory); } -} \ No newline at end of file +} diff --git a/src/domain/token/entity.rs b/src/domain/token/entity.rs index 5985f5a..d2d6c74 100644 --- a/src/domain/token/entity.rs +++ b/src/domain/token/entity.rs @@ -67,15 +67,19 @@ mod tests { assert_eq!(token.description, "Test"); assert_eq!(token.token, "abc123"); assert!(token.create_at < Utc::now()); - assert_eq!(token.expire_at, token.create_at + Duration::days(TOKEN_EXPIRE_IN_DAYS)); + assert_eq!( + token.expire_at, + token.create_at + Duration::days(TOKEN_EXPIRE_IN_DAYS) + ); } #[test] fn test_token_display() { let token = Token::new(1, "Test".to_string(), "abc123".to_string()).unwrap(); - let expected = format!("id: {}, user_id: {}, expire_at: {}", - token.id, token.user_id, token.expire_at); + let expected = format!( + "id: {}, user_id: {}, expire_at: {}", + token.id, token.user_id, token.expire_at + ); assert_eq!(expected, format!("{}", token)); } } - diff --git a/src/domain/token/mod.rs b/src/domain/token/mod.rs 
index 184732b..cdece9e 100644 --- a/src/domain/token/mod.rs +++ b/src/domain/token/mod.rs @@ -1,2 +1,2 @@ pub mod entity; -pub mod repository; \ No newline at end of file +pub mod repository; diff --git a/src/domain/token/repository.rs b/src/domain/token/repository.rs index b17e527..021fad1 100644 --- a/src/domain/token/repository.rs +++ b/src/domain/token/repository.rs @@ -22,7 +22,7 @@ use async_trait::async_trait; pub trait Repository: Send + Sync { async fn create(&self, user: Token) -> Result; async fn get_token_by_id(&self, id: i32) -> Result; - async fn get_token_by_value(&self, token: &str) -> Result; + async fn get_token_by_value(&self, token: &str) -> Result; async fn delete_by_user_and_id(&self, id: i32, user_id: i32) -> Result<()>; async fn get_token_by_user_id(&self, id: i32) -> Result>; } diff --git a/src/domain/user/entity.rs b/src/domain/user/entity.rs index ad72d75..68444a7 100644 --- a/src/domain/user/entity.rs +++ b/src/domain/user/entity.rs @@ -17,31 +17,21 @@ use crate::util::error::Result; use std::fmt::{Display, Formatter}; - - #[derive(Debug, Clone, PartialEq)] pub struct User { pub id: i32, - pub email: String - + pub email: String, } impl Display for User { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "id: {}, email: {}", - self.id, self.email - ) + write!(f, "id: {}, email: {}", self.id, self.email) } } impl User { pub fn new(email: String) -> Result { - Ok(User { - id: 0, - email, - }) + Ok(User { id: 0, email }) } } @@ -63,4 +53,3 @@ mod tests { assert_eq!(expected, format!("{}", user)); } } - diff --git a/src/domain/user/mod.rs b/src/domain/user/mod.rs index 184732b..cdece9e 100644 --- a/src/domain/user/mod.rs +++ b/src/domain/user/mod.rs @@ -1,2 +1,2 @@ pub mod entity; -pub mod repository; \ No newline at end of file +pub mod repository; diff --git a/src/infra/database/model/clusterkey/dto.rs b/src/infra/database/model/clusterkey/dto.rs index d16183f..3168734 100644 --- a/src/infra/database/model/clusterkey/dto.rs +++ b/src/infra/database/model/clusterkey/dto.rs @@ -14,12 +14,11 @@ * */ - use crate::domain::clusterkey::entity::ClusterKey; -use sqlx::types::chrono; use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; +use sqlx::types::chrono; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] #[sea_orm(table_name = "cluster_key")] @@ -51,8 +50,8 @@ impl From for ClusterKey { #[cfg(test)] mod tests { + use super::{ClusterKey, Model}; use chrono::Utc; - use super::{ClusterKey,Model}; #[test] fn test_cluster_key_entity_from_dto() { @@ -61,7 +60,7 @@ mod tests { data: vec![1, 2, 3], algorithm: "algo".to_string(), identity: "id".to_string(), - create_at: Utc::now() + create_at: Utc::now(), }; let create_at = dto.create_at.clone(); @@ -72,6 +71,4 @@ mod tests { assert_eq!(key.identity, "id"); assert_eq!(key.create_at, create_at); } - } - diff --git a/src/infra/database/model/clusterkey/repository.rs b/src/infra/database/model/clusterkey/repository.rs index edf8334..7ed8ab9 100644 --- a/src/infra/database/model/clusterkey/repository.rs +++ b/src/infra/database/model/clusterkey/repository.rs @@ -15,13 +15,15 @@ */ use super::dto::Entity as ClusterKeyDTO; -use crate::infra::database::model::clusterkey; -use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, QueryOrder}; use crate::domain::clusterkey::entity::ClusterKey; use crate::domain::clusterkey::repository::Repository; -use crate::util::error::{Result, Error}; +use 
crate::infra::database::model::clusterkey; +use crate::util::error::{Error, Result}; use async_trait::async_trait; use sea_orm::sea_query::OnConflict; +use sea_orm::{ + ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, +}; #[derive(Clone)] pub struct ClusterKeyRepository<'a> { @@ -30,9 +32,7 @@ pub struct ClusterKeyRepository<'a> { impl<'a> ClusterKeyRepository<'a> { pub fn new(db_connection: &'a DatabaseConnection) -> Self { - Self { - db_connection, - } + Self { db_connection } } } @@ -47,73 +47,74 @@ impl<'a> Repository for ClusterKeyRepository<'a> { ..Default::default() }; //TODO: https://github.com/SeaQL/sea-orm/issues/1790 - ClusterKeyDTO::insert(cluster_key).on_conflict(OnConflict::new() - .update_column(clusterkey::dto::Column::Id).to_owned() - ).exec(self.db_connection).await?; + ClusterKeyDTO::insert(cluster_key) + .on_conflict( + OnConflict::new() + .update_column(clusterkey::dto::Column::Id) + .to_owned(), + ) + .exec(self.db_connection) + .await?; Ok(()) } async fn get_latest(&self, algorithm: &str) -> Result> { - match ClusterKeyDTO::find().filter( - clusterkey::dto::Column::Algorithm.eq(algorithm) - ).order_by_desc(clusterkey::dto::Column::Id).one( - self.db_connection).await? { - None => { - Ok(None) - } - Some(cluster_key) => { - Ok(Some(ClusterKey::from(cluster_key))) - } + match ClusterKeyDTO::find() + .filter(clusterkey::dto::Column::Algorithm.eq(algorithm)) + .order_by_desc(clusterkey::dto::Column::Id) + .one(self.db_connection) + .await? + { + None => Ok(None), + Some(cluster_key) => Ok(Some(ClusterKey::from(cluster_key))), } } async fn get_by_id(&self, id: i32) -> Result { - match ClusterKeyDTO::find_by_id(id).one(self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } - Some(cluster_key) => { - Ok(ClusterKey::from(cluster_key)) - } + match ClusterKeyDTO::find_by_id(id) + .one(self.db_connection) + .await? 
+ { + None => Err(Error::NotFoundError), + Some(cluster_key) => Ok(ClusterKey::from(cluster_key)), } } async fn delete_by_id(&self, id: i32) -> Result<()> { - let _ = ClusterKeyDTO::delete_by_id( - id).exec(self.db_connection).await?; + let _ = ClusterKeyDTO::delete_by_id(id) + .exec(self.db_connection) + .await?; Ok(()) } } #[cfg(test)] mod tests { - use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; use crate::domain::clusterkey::entity::ClusterKey; use crate::domain::clusterkey::repository::Repository; use crate::infra::database::model::clusterkey::dto; + use crate::infra::database::model::clusterkey::repository::ClusterKeyRepository; use crate::util::error::Result; - use crate::infra::database::model::clusterkey::repository::{ClusterKeyRepository}; + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; #[tokio::test] async fn test_cluster_key_repository_create_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 0, - data: vec![], - algorithm: "".to_string(), - identity: "".to_string(), - create_at: now.clone(), - }], - ]).append_exec_results([ - MockExecResult{ + .append_query_results([vec![dto::Model { + id: 0, + data: vec![], + algorithm: "".to_string(), + identity: "".to_string(), + create_at: now.clone(), + }]]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let key_repository = ClusterKeyRepository::new(&db); - let key = ClusterKey{ + let key = ClusterKey { id: 0, data: vec![], algorithm: "fake_algorithm".to_string(), @@ -128,7 +129,12 @@ mod tests { Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"INSERT INTO `cluster_key` (`data`, `algorithm`, `identity`, `create_at`) VALUES (?, ?, ?, ?) 
ON DUPLICATE KEY UPDATE `id` = VALUES(`id`)"#, - [vec![].into(), "fake_algorithm".into(), "123".into(), now.clone().into()] + [ + vec![].into(), + "fake_algorithm".into(), + "123".into(), + now.clone().into() + ] ), ] ); @@ -140,20 +146,18 @@ mod tests { async fn test_cluster_key_repository_delete_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - data: vec![], - algorithm: "fake_algorithm".to_string(), - identity: "123".to_string(), - create_at: now.clone(), - }], - ]).append_exec_results([ - MockExecResult{ + .append_query_results([vec![dto::Model { + id: 1, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }]]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let key_repository = ClusterKeyRepository::new(&db); assert_eq!(key_repository.delete_by_id(1).await?, ()); @@ -185,14 +189,14 @@ mod tests { create_at: now.clone(), }], vec![dto::Model { - id: 2, - data: vec![], - algorithm: "fake_algorithm".to_string(), - identity: "123".to_string(), - create_at: now.clone(), - }], - ], - ).into_connection(); + id: 2, + data: vec![], + algorithm: "fake_algorithm".to_string(), + identity: "123".to_string(), + create_at: now.clone(), + }], + ]) + .into_connection(); let key_repository = ClusterKeyRepository::new(&db); assert_eq!( @@ -236,4 +240,3 @@ mod tests { Ok(()) } } - diff --git a/src/infra/database/model/datakey/dto.rs b/src/infra/database/model/datakey/dto.rs index 0954fef..e890217 100644 --- a/src/infra/database/model/datakey/dto.rs +++ b/src/infra/database/model/datakey/dto.rs @@ -14,18 +14,18 @@ * */ -use crate::domain::datakey::entity::{DataKey, KeyState, Visibility}; use crate::domain::datakey::entity::KeyType; +use crate::domain::datakey::entity::{DataKey, KeyState, Visibility}; use crate::domain::datakey::traits::ExtendableAttributes; -use crate::util::error::{Error}; +use crate::util::error::Error; use crate::util::key; use chrono::{DateTime, Utc}; -use std::str::FromStr; use sea_orm::ActiveValue::Set; +use std::str::FromStr; use sea_orm::entity::prelude::*; -use sea_orm::{NotSet}; +use sea_orm::NotSet; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] @@ -51,10 +51,9 @@ pub struct Model { pub user_email: Option, pub request_delete_users: Option, pub request_revoke_users: Option, - pub x509_crl_update_at: Option> + pub x509_crl_update_at: Option>, } - #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { #[sea_orm(has_one = "super::super::user::dto::Entity")] @@ -77,8 +76,6 @@ impl Related for Entity { impl ActiveModelBehavior for ActiveModel {} - - impl TryFrom for DataKey { type Error = Error; @@ -123,15 +120,9 @@ impl TryFrom for ActiveModel { parent_id: Set(data_key.parent_id), fingerprint: Set(data_key.fingerprint.clone()), serial_number: Set(data_key.serial_number), - private_key: Set(key::encode_u8_to_hex_string( - &data_key.private_key) - ), - public_key: Set(key::encode_u8_to_hex_string( - &data_key.public_key) - ), - certificate: Set(key::encode_u8_to_hex_string( - &data_key.certificate - )), + private_key: Set(key::encode_u8_to_hex_string(&data_key.private_key)), + public_key: Set(key::encode_u8_to_hex_string(&data_key.public_key)), + certificate: Set(key::encode_u8_to_hex_string(&data_key.certificate)), create_at: 
Set(data_key.create_at), expire_at: Set(data_key.expire_at), key_state: Set(data_key.key_state.to_string()), @@ -146,7 +137,7 @@ impl TryFrom for ActiveModel { #[cfg(test)] mod tests { use super::*; - use crate::domain::datakey::entity::{Visibility}; + use crate::domain::datakey::entity::Visibility; #[test] fn test_data_key_entity_from_dto() { @@ -177,10 +168,8 @@ mod tests { assert_eq!(key.name, "Test Key"); assert_eq!(key.visibility, Visibility::Public); assert_eq!(key.key_type, KeyType::OpenPGP); - assert_eq!(key.private_key, vec![7,8,9,10]); - assert_eq!(key.public_key, vec![4,5,6]); - assert_eq!(key.certificate, vec![1,2,3]); + assert_eq!(key.private_key, vec![7, 8, 9, 10]); + assert_eq!(key.public_key, vec![4, 5, 6]); + assert_eq!(key.certificate, vec![1, 2, 3]); } - } - diff --git a/src/infra/database/model/datakey/repository.rs b/src/infra/database/model/datakey/repository.rs index bfd1c1a..09f7f3a 100644 --- a/src/infra/database/model/datakey/repository.rs +++ b/src/infra/database/model/datakey/repository.rs @@ -14,39 +14,49 @@ * */ -use super::dto as datakey_dto; -use super::super::user::dto as user_dto; use super::super::request_delete::dto as request_dto; +use super::super::user::dto as user_dto; use super::super::x509_crl_content::dto as crl_content_dto; use super::super::x509_revoked_key::dto as revoked_key_dto; -use crate::domain::datakey::entity::{DataKey, DatakeyPaginationQuery, KeyState, KeyType, PagedDatakey, PagedMeta, ParentKey, RevokedKey, Visibility, X509CRL, X509RevokeReason}; +use super::dto as datakey_dto; +use crate::domain::datakey::entity::{ + DataKey, DatakeyPaginationQuery, KeyState, KeyType, PagedDatakey, PagedMeta, ParentKey, + RevokedKey, Visibility, X509RevokeReason, X509CRL, +}; use crate::domain::datakey::repository::Repository; +use crate::infra::database::model::request_delete::dto::RequestType; use crate::util::error::{Error, Result}; +use crate::util::key::encode_u8_to_hex_string; use async_trait::async_trait; use chrono::Duration; -use sea_query::Expr; use chrono::Utc; -use sea_orm::{Condition, Iterable, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, TransactionTrait, DatabaseTransaction, QuerySelect, JoinType, sea_query, RelationTrait, RelationBuilder, ExecResult, ConnectionTrait, Statement, DatabaseBackend, NotSet, PaginatorTrait}; use sea_orm::sea_query::{Alias, IntoCondition, OnConflict}; -use crate::infra::database::model::request_delete::dto::RequestType; -use crate::util::key::encode_u8_to_hex_string; +use sea_orm::{ + sea_query, ActiveValue::Set, ColumnTrait, Condition, ConnectionTrait, DatabaseBackend, + DatabaseConnection, DatabaseTransaction, EntityTrait, ExecResult, Iterable, JoinType, NotSet, + PaginatorTrait, QueryFilter, QuerySelect, RelationBuilder, RelationTrait, Statement, + TransactionTrait, +}; +use sea_query::Expr; const PUBLIC_KEY_PENDING_THRESHOLD: i32 = 3; const PRIVATE_KEY_PENDING_THRESHOLD: i32 = 1; #[derive(Clone)] pub struct DataKeyRepository<'a> { - db_connection: &'a DatabaseConnection + db_connection: &'a DatabaseConnection, } impl<'a> DataKeyRepository<'a> { pub fn new(db_connection: &'a DatabaseConnection) -> Self { - Self { - db_connection - } + Self { db_connection } } - async fn create_pending_operation(&self, pending_operation: request_dto::Model, tx: &DatabaseTransaction) -> Result<()> { + async fn create_pending_operation( + &self, + pending_operation: request_dto::Model, + tx: &DatabaseTransaction, + ) -> Result<()> { let operation = request_dto::ActiveModel { user_id: 
Set(pending_operation.user_id), key_id: Set(pending_operation.key_id), @@ -56,23 +66,44 @@ impl<'a> DataKeyRepository<'a> { ..Default::default() }; //TODO: https://github.com/SeaQL/sea-orm/issues/1790 - request_dto::Entity::insert(operation).on_conflict(OnConflict::new() - .update_column(request_dto::Column::Id).to_owned() - ).exec(tx).await?; + request_dto::Entity::insert(operation) + .on_conflict( + OnConflict::new() + .update_column(request_dto::Column::Id) + .to_owned(), + ) + .exec(tx) + .await?; Ok(()) } - async fn delete_pending_operation(&self, user_id: i32, id: i32, request_type: RequestType, tx: &DatabaseTransaction) -> Result<()> { - let _ = request_dto::Entity::delete_many().filter(Condition::all() - .add(request_dto::Column::UserId.eq(user_id)) - .add(request_dto::Column::RequestType.eq(request_type.to_string())) - .add(request_dto::Column::KeyId.eq(id))).exec(tx) + async fn delete_pending_operation( + &self, + user_id: i32, + id: i32, + request_type: RequestType, + tx: &DatabaseTransaction, + ) -> Result<()> { + let _ = request_dto::Entity::delete_many() + .filter( + Condition::all() + .add(request_dto::Column::UserId.eq(user_id)) + .add(request_dto::Column::RequestType.eq(request_type.to_string())) + .add(request_dto::Column::KeyId.eq(id)), + ) + .exec(tx) .await?; Ok(()) } - async fn create_revoke_record(&self, key_id: i32, ca_id: i32, reason: X509RevokeReason, tx: &DatabaseTransaction) -> Result<()> { - let revoked = revoked_key_dto::ActiveModel{ + async fn create_revoke_record( + &self, + key_id: i32, + ca_id: i32, + reason: X509RevokeReason, + tx: &DatabaseTransaction, + ) -> Result<()> { + let revoked = revoked_key_dto::ActiveModel { id: Default::default(), key_id: Set(key_id), ca_id: Set(ca_id), @@ -81,27 +112,46 @@ impl<'a> DataKeyRepository<'a> { serial_number: NotSet, }; //TODO: https://github.com/SeaQL/sea-orm/issues/1790 - revoked_key_dto::Entity::insert(revoked).on_conflict(OnConflict::new() - .update_column(request_dto::Column::Id).to_owned() - ).exec(tx).await?; + revoked_key_dto::Entity::insert(revoked) + .on_conflict( + OnConflict::new() + .update_column(request_dto::Column::Id) + .to_owned(), + ) + .exec(tx) + .await?; Ok(()) } - async fn delete_revoke_record(&self, key_id: i32, ca_id: i32, tx: &DatabaseTransaction) -> Result<()> { - let _ = revoked_key_dto::Entity::delete_many().filter(Condition::all() - .add(revoked_key_dto::Column::KeyId.eq(key_id)) - .add(revoked_key_dto::Column::CaId.eq(ca_id))).exec(tx) + async fn delete_revoke_record( + &self, + key_id: i32, + ca_id: i32, + tx: &DatabaseTransaction, + ) -> Result<()> { + let _ = revoked_key_dto::Entity::delete_many() + .filter( + Condition::all() + .add(revoked_key_dto::Column::KeyId.eq(key_id)) + .add(revoked_key_dto::Column::CaId.eq(ca_id)), + ) + .exec(tx) .await?; Ok(()) } - fn get_pending_operation_relation(&self, request_type: RequestType) -> RelationBuilder { - request_dto::Entity::belongs_to(datakey_dto::Entity).from( - request_dto::Column::KeyId).to( - datakey_dto::Column::Id).on_condition(move |left, _right| { - Expr::col((left, request_dto::Column::RequestType)).eq(request_type.clone().to_string()).into_condition() - } - ) + fn get_pending_operation_relation( + &self, + request_type: RequestType, + ) -> RelationBuilder { + request_dto::Entity::belongs_to(datakey_dto::Entity) + .from(request_dto::Column::KeyId) + .to(datakey_dto::Column::Id) + .on_condition(move |left, _right| { + Expr::col((left, request_dto::Column::RequestType)) + .eq(request_type.clone().to_string()) + .into_condition() 
+ }) } async fn _obtain_datakey_parent(&self, datakey: &mut DataKey) -> Result<()> { @@ -114,11 +164,13 @@ impl<'a> DataKeyRepository<'a> { private_key: parent.private_key.clone(), public_key: parent.public_key.clone(), certificate: parent.certificate.clone(), - attributes: parent.attributes + attributes: parent.attributes, }) } _ => { - return Err(Error::DatabaseError("unable to find parent key".to_string())); + return Err(Error::DatabaseError( + "unable to find parent key".to_string(), + )); } } } @@ -130,9 +182,13 @@ impl<'a> DataKeyRepository<'a> { impl<'a> Repository for DataKeyRepository<'a> { async fn create(&self, data_key: DataKey) -> Result { let dto = datakey_dto::ActiveModel::try_from(data_key)?; - let insert_result =datakey_dto::Entity::insert(dto).exec(self.db_connection).await?; + let insert_result = datakey_dto::Entity::insert(dto) + .exec(self.db_connection) + .await?; - let mut datakey = self.get_by_id_or_name(Some(insert_result.last_insert_id), None, true).await?; + let mut datakey = self + .get_by_id_or_name(Some(insert_result.last_insert_id), None, true) + .await?; //fetch parent key if 'parent_id' exists. if let Err(err) = self._obtain_datakey_parent(&mut datakey).await { warn!("failed to create datakey {} {}", datakey.name, err); @@ -143,17 +199,25 @@ impl<'a> Repository for DataKeyRepository<'a> { } async fn delete(&self, id: i32) -> Result<()> { - datakey_dto::Entity::delete_by_id(id).exec(self.db_connection).await?; + datakey_dto::Entity::delete_by_id(id) + .exec(self.db_connection) + .await?; Ok(()) } - async fn get_all_keys(&self, user_id: i32, query: DatakeyPaginationQuery) -> Result { - let mut conditions = Condition::all().add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())); + async fn get_all_keys( + &self, + user_id: i32, + query: DatakeyPaginationQuery, + ) -> Result { + let mut conditions = + Condition::all().add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())); if let Some(name) = query.name { conditions = conditions.add(datakey_dto::Column::Name.like(format!("%{}%", name))) } if let Some(desc) = query.description { - conditions = conditions.add(datakey_dto::Column::Description.like(format!("%{}%", desc))) + conditions = + conditions.add(datakey_dto::Column::Description.like(format!("%{}%", desc))) } if let Some(k_type) = query.key_type { conditions = conditions.add(datakey_dto::Column::KeyType.eq(k_type)) @@ -164,45 +228,94 @@ impl<'a> Repository for DataKeyRepository<'a> { conditions = conditions.add(datakey_dto::Column::User.eq(user_id)) } } - let paginator = datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).exprs( - [Expr::cust("user_table.email as user_email"), - Expr::cust("GROUP_CONCAT(request_delete_table.user_email) as request_delete_users"), - Expr::cust("GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users")]).join_as_rev( - JoinType::InnerJoin, user_dto::Relation::Datakey.def(), Alias::new("user_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Delete).into(), - Alias::new("request_delete_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Revoke).into(), - Alias::new("request_revoke_table")).group_by(datakey_dto::Column::Id).filter(conditions).paginate(self.db_connection, query.page_size); + let 
paginator = datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .exprs([ + Expr::cust("user_table.email as user_email"), + Expr::cust("GROUP_CONCAT(request_delete_table.user_email) as request_delete_users"), + Expr::cust("GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users"), + ]) + .join_as_rev( + JoinType::InnerJoin, + user_dto::Relation::Datakey.def(), + Alias::new("user_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Delete) + .into(), + Alias::new("request_delete_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Revoke) + .into(), + Alias::new("request_revoke_table"), + ) + .group_by(datakey_dto::Column::Id) + .filter(conditions) + .paginate(self.db_connection, query.page_size); let total_numbers = paginator.num_items().await?; let mut results = vec![]; - for dto in paginator.fetch_page(query.page_number - 1).await?.into_iter() { + for dto in paginator + .fetch_page(query.page_number - 1) + .await? + .into_iter() + { results.push(DataKey::try_from(dto)?); } - Ok(PagedDatakey{ - data:results, - meta: PagedMeta{ + Ok(PagedDatakey { + data: results, + meta: PagedMeta { total_count: total_numbers, - } + }, }) } async fn get_keys_for_crl_update(&self, duration: Duration) -> Result> { let now = Utc::now(); - match datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).column_as( - Expr::col((Alias::new("crl_table"), crl_content_dto::Column::UpdateAt)), "x509_crl_update_at").join_as_rev( - JoinType::LeftJoin, crl_content_dto::Relation::Datakey.def(), Alias::new("crl_table")).filter( - Condition::all().add( - Condition::any().add(datakey_dto::Column::KeyType.eq(KeyType::X509CA.to_string()) - ).add(datakey_dto::Column::KeyType.eq(KeyType::X509ICA.to_string()))).add( - datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())) - ).all(self.db_connection).await { - Err(_) => { - Ok(vec![]) - } + match datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .column_as( + Expr::col((Alias::new("crl_table"), crl_content_dto::Column::UpdateAt)), + "x509_crl_update_at", + ) + .join_as_rev( + JoinType::LeftJoin, + crl_content_dto::Relation::Datakey.def(), + Alias::new("crl_table"), + ) + .filter( + Condition::all() + .add( + Condition::any() + .add(datakey_dto::Column::KeyType.eq(KeyType::X509CA.to_string())) + .add(datakey_dto::Column::KeyType.eq(KeyType::X509ICA.to_string())), + ) + .add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())), + ) + .all(self.db_connection) + .await + { + Err(_) => Ok(vec![]), Ok(keys) => { let mut results = vec![]; for dto in keys.into_iter() { @@ -221,25 +334,39 @@ impl<'a> Repository for DataKeyRepository<'a> { } async fn get_revoked_serial_number_by_parent_id(&self, id: i32) -> Result> { - match 
revoked_key_dto::Entity::find().select_only().columns( - revoked_key_dto::Column::iter().filter(|col| - match col { - revoked_key_dto::Column::SerialNumber => false, - _ => true, - })).column_as( - Expr::col((Alias::new("datakey_table"), datakey_dto::Column::SerialNumber)), - "serial_number").join_as_rev( - JoinType::InnerJoin, datakey_dto::Entity::belongs_to(revoked_key_dto::Entity).from( - datakey_dto::Column::Id).to( - revoked_key_dto::Column::KeyId).on_condition( - move |left, right| { - Condition::all().add( - Expr::col((left, datakey_dto::Column::KeyState)).eq(KeyState::Revoked.to_string())).add( - Expr::col((right, revoked_key_dto::Column::CaId)).eq(id) - ).into_condition() - } - ).into(), - Alias::new("datakey_table")).all(self.db_connection).await { + match revoked_key_dto::Entity::find() + .select_only() + .columns(revoked_key_dto::Column::iter().filter(|col| match col { + revoked_key_dto::Column::SerialNumber => false, + _ => true, + })) + .column_as( + Expr::col(( + Alias::new("datakey_table"), + datakey_dto::Column::SerialNumber, + )), + "serial_number", + ) + .join_as_rev( + JoinType::InnerJoin, + datakey_dto::Entity::belongs_to(revoked_key_dto::Entity) + .from(datakey_dto::Column::Id) + .to(revoked_key_dto::Column::KeyId) + .on_condition(move |left, right| { + Condition::all() + .add( + Expr::col((left, datakey_dto::Column::KeyState)) + .eq(KeyState::Revoked.to_string()), + ) + .add(Expr::col((right, revoked_key_dto::Column::CaId)).eq(id)) + .into_condition() + }) + .into(), + Alias::new("datakey_table"), + ) + .all(self.db_connection) + .await + { Err(err) => { warn!("failed to query database {:?}", err); Err(Error::NotFoundError) @@ -254,67 +381,135 @@ impl<'a> Repository for DataKeyRepository<'a> { } } - async fn get_by_id_or_name(&self, id: Option, name: Option, raw_datakey: bool) -> Result { + async fn get_by_id_or_name( + &self, + id: Option, + name: Option, + raw_datakey: bool, + ) -> Result { let mut conditions = Condition::all(); if let Some(key_id) = id { conditions = conditions.add(datakey_dto::Column::Id.eq(key_id)) } else if let Some(key_name) = name { conditions = conditions.add(datakey_dto::Column::Name.eq(key_name)) } else { - return Err(Error::ParameterError("both datakey name and id are empty".to_string())) + return Err(Error::ParameterError( + "both datakey name and id are empty".to_string(), + )); } - conditions = conditions.add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())); + conditions = + conditions.add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())); if !raw_datakey { - match datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).exprs( - [Expr::cust("user_table.email as user_email"), - Expr::cust("GROUP_CONCAT(request_delete_table.user_email) as request_delete_users"), - Expr::cust("GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users")]).join_as_rev( - JoinType::InnerJoin, user_dto::Relation::Datakey.def(), Alias::new("user_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Delete).into(), - Alias::new("request_delete_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Revoke).into(), - Alias::new("request_revoke_table")).group_by(datakey_dto::Column::Id).filter( - conditions).one(self.db_connection).await? 
{ - None => { - Err(Error::NotFoundError) - } - Some(datakey) => { - Ok(DataKey::try_from(datakey)?) - } + match datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .exprs([ + Expr::cust("user_table.email as user_email"), + Expr::cust( + "GROUP_CONCAT(request_delete_table.user_email) as request_delete_users", + ), + Expr::cust( + "GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users", + ), + ]) + .join_as_rev( + JoinType::InnerJoin, + user_dto::Relation::Datakey.def(), + Alias::new("user_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Delete) + .into(), + Alias::new("request_delete_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Revoke) + .into(), + Alias::new("request_revoke_table"), + ) + .group_by(datakey_dto::Column::Id) + .filter(conditions) + .one(self.db_connection) + .await? + { + None => Err(Error::NotFoundError), + Some(datakey) => Ok(DataKey::try_from(datakey)?), } } else { - match datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).filter(conditions).one(self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } - Some(datakey) => { - Ok(DataKey::try_from(datakey)?) - } + match datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .filter(conditions) + .one(self.db_connection) + .await? 
+ { + None => Err(Error::NotFoundError), + Some(datakey) => Ok(DataKey::try_from(datakey)?), } } - } async fn get_by_parent_id(&self, parent_id: i32) -> Result> { - match datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).exprs( - [Expr::cust("user_table.email as user_email"), - Expr::cust("GROUP_CONCAT(request_delete_table.user_email) as request_delete_users"), - Expr::cust("GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users")]).join_as_rev( - JoinType::InnerJoin, user_dto::Relation::Datakey.def(), Alias::new("user_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Delete).into(), - Alias::new("request_delete_table")).join_as_rev( - JoinType::LeftJoin, self.get_pending_operation_relation(RequestType::Revoke).into(), - Alias::new("request_revoke_table")).group_by(datakey_dto::Column::Id).filter( - Condition::all().add( - datakey_dto::Column::ParentId.eq(parent_id)).add( - datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())) - ).all(self.db_connection).await { + match datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .exprs([ + Expr::cust("user_table.email as user_email"), + Expr::cust("GROUP_CONCAT(request_delete_table.user_email) as request_delete_users"), + Expr::cust("GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users"), + ]) + .join_as_rev( + JoinType::InnerJoin, + user_dto::Relation::Datakey.def(), + Alias::new("user_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Delete) + .into(), + Alias::new("request_delete_table"), + ) + .join_as_rev( + JoinType::LeftJoin, + self.get_pending_operation_relation(RequestType::Revoke) + .into(), + Alias::new("request_revoke_table"), + ) + .group_by(datakey_dto::Column::Id) + .filter( + Condition::all() + .add(datakey_dto::Column::ParentId.eq(parent_id)) + .add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())), + ) + .all(self.db_connection) + .await + { Err(err) => { warn!("failed to query database {:?}", err); Err(Error::NotFoundError) @@ -331,46 +526,80 @@ impl<'a> Repository for DataKeyRepository<'a> { async fn update_state(&self, id: i32, state: KeyState) -> Result<()> { //Note: if the key in deleted status, it cannot be updated to other states - let _ = datakey_dto::Entity::update_many().col_expr( - datakey_dto::Column::KeyState, Expr::value(state.to_string()) - ).filter(Condition::all().add( - datakey_dto::Column::Id.eq(id)).add( - datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())) - ).exec(self.db_connection).await?; + let _ = datakey_dto::Entity::update_many() + .col_expr( + datakey_dto::Column::KeyState, + Expr::value(state.to_string()), + ) + .filter( + Condition::all() + .add(datakey_dto::Column::Id.eq(id)) + .add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())), + ) + .exec(self.db_connection) + .await?; Ok(()) } async fn update_key_data(&self, data_key: DataKey) -> Result<()> { //Note: if the key in deleted status, it cannot be updated to other states - let _ = 
datakey_dto::Entity::update_many().col_expr( - datakey_dto::Column::SerialNumber, Expr::value(data_key.serial_number) - ).col_expr( - datakey_dto::Column::Fingerprint, Expr::value(data_key.fingerprint) - ).col_expr( - datakey_dto::Column::PrivateKey, Expr::value(encode_u8_to_hex_string(&data_key.private_key)) - ).col_expr( - datakey_dto::Column::PublicKey, Expr::value(encode_u8_to_hex_string(&data_key.public_key)) - ).col_expr( - datakey_dto::Column::Certificate, Expr::value(encode_u8_to_hex_string(&data_key.certificate)) - ).filter(Condition::all().add( - datakey_dto::Column::Id.eq(data_key.id)).add( - datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())) - ).exec(self.db_connection).await?; + let _ = datakey_dto::Entity::update_many() + .col_expr( + datakey_dto::Column::SerialNumber, + Expr::value(data_key.serial_number), + ) + .col_expr( + datakey_dto::Column::Fingerprint, + Expr::value(data_key.fingerprint), + ) + .col_expr( + datakey_dto::Column::PrivateKey, + Expr::value(encode_u8_to_hex_string(&data_key.private_key)), + ) + .col_expr( + datakey_dto::Column::PublicKey, + Expr::value(encode_u8_to_hex_string(&data_key.public_key)), + ) + .col_expr( + datakey_dto::Column::Certificate, + Expr::value(encode_u8_to_hex_string(&data_key.certificate)), + ) + .filter( + Condition::all() + .add(datakey_dto::Column::Id.eq(data_key.id)) + .add(datakey_dto::Column::KeyState.ne(KeyState::Deleted.to_string())), + ) + .exec(self.db_connection) + .await?; Ok(()) } - async fn get_enabled_key_by_type_and_name_with_parent_key(&self, key_type: String, name: String) -> Result { - match datakey_dto::Entity::find().select_only().columns( - datakey_dto::Column::iter().filter(|col| - !matches!(col, datakey_dto::Column::UserEmail | datakey_dto::Column::RequestDeleteUsers | datakey_dto::Column::RequestRevokeUsers | datakey_dto::Column::X509CrlUpdateAt))).filter( - Condition::all().add( - datakey_dto::Column::Name.eq(name)).add( - datakey_dto::Column::KeyType.eq(key_type)).add( - datakey_dto::Column::KeyState.eq(KeyState::Enabled.to_string())) - ).one(self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } + async fn get_enabled_key_by_type_and_name_with_parent_key( + &self, + key_type: String, + name: String, + ) -> Result { + match datakey_dto::Entity::find() + .select_only() + .columns(datakey_dto::Column::iter().filter(|col| { + !matches!( + col, + datakey_dto::Column::UserEmail + | datakey_dto::Column::RequestDeleteUsers + | datakey_dto::Column::RequestRevokeUsers + | datakey_dto::Column::X509CrlUpdateAt + ) + })) + .filter( + Condition::all() + .add(datakey_dto::Column::Name.eq(name)) + .add(datakey_dto::Column::KeyType.eq(key_type)) + .add(datakey_dto::Column::KeyState.eq(KeyState::Enabled.to_string())), + ) + .one(self.db_connection) + .await? + { + None => Err(Error::NotFoundError), Some(datakey) => { let mut result = DataKey::try_from(datakey)?; self._obtain_datakey_parent(&mut result).await?; @@ -379,48 +608,95 @@ impl<'a> Repository for DataKeyRepository<'a> { } } - async fn request_delete_key(&self, user_id: i32, user_email: String, id: i32, public_key: bool) -> Result<()> { + async fn request_delete_key( + &self, + user_id: i32, + user_email: String, + id: i32, + public_key: bool, + ) -> Result<()> { let txn = self.db_connection.begin().await?; - let threshold = if public_key { PUBLIC_KEY_PENDING_THRESHOLD } else { PRIVATE_KEY_PENDING_THRESHOLD }; + let threshold = if public_key { + PUBLIC_KEY_PENDING_THRESHOLD + } else { + PRIVATE_KEY_PENDING_THRESHOLD + }; //1. 
update key state to pending delete if needed. - let _ = datakey_dto::Entity::update_many().col_expr( - datakey_dto::Column::KeyState, Expr::value(KeyState::PendingDelete.to_string()) - ).filter(datakey_dto::Column::Id.eq(id)).exec(&txn).await?; + let _ = datakey_dto::Entity::update_many() + .col_expr( + datakey_dto::Column::KeyState, + Expr::value(KeyState::PendingDelete.to_string()), + ) + .filter(datakey_dto::Column::Id.eq(id)) + .exec(&txn) + .await?; //2. add request delete record let pending_delete = request_dto::Model::new_for_delete(id, user_id, user_email); self.create_pending_operation(pending_delete, &txn).await?; //3. delete datakey if pending delete count >= threshold - let _: ExecResult = txn.execute(Statement::from_sql_and_values( - DatabaseBackend::MySql, - "UPDATE data_key SET key_state = ? \ + let _: ExecResult = txn + .execute(Statement::from_sql_and_values( + DatabaseBackend::MySql, + "UPDATE data_key SET key_state = ? \ WHERE id = ? AND ( \ SELECT COUNT(*) FROM pending_operation WHERE key_id = ?) >= ?", - [KeyState::Deleted.to_string().into(), id.into(), id.into(), threshold.into()], - )).await?; + [ + KeyState::Deleted.to_string().into(), + id.into(), + id.into(), + threshold.into(), + ], + )) + .await?; txn.commit().await?; Ok(()) } - async fn request_revoke_key(&self, user_id: i32, user_email: String, id: i32, parent_id: i32, reason: X509RevokeReason, public_key: bool) -> Result<()> { + async fn request_revoke_key( + &self, + user_id: i32, + user_email: String, + id: i32, + parent_id: i32, + reason: X509RevokeReason, + public_key: bool, + ) -> Result<()> { let txn = self.db_connection.begin().await?; - let threshold = if public_key { PUBLIC_KEY_PENDING_THRESHOLD } else { PRIVATE_KEY_PENDING_THRESHOLD }; + let threshold = if public_key { + PUBLIC_KEY_PENDING_THRESHOLD + } else { + PRIVATE_KEY_PENDING_THRESHOLD + }; //1. update key state to pending delete if needed. - let _ = datakey_dto::Entity::update_many().col_expr( - datakey_dto::Column::KeyState, Expr::value(KeyState::PendingDelete.to_string()) - ).filter(datakey_dto::Column::Id.eq(id)).exec(&txn).await?; + let _ = datakey_dto::Entity::update_many() + .col_expr( + datakey_dto::Column::KeyState, + Expr::value(KeyState::PendingDelete.to_string()), + ) + .filter(datakey_dto::Column::Id.eq(id)) + .exec(&txn) + .await?; //2. add request revoke pending record let pending_revoke = request_dto::Model::new_for_revoke(id, user_id, user_email); self.create_pending_operation(pending_revoke, &txn).await?; //3. add revoked record - self.create_revoke_record(id, parent_id, reason, &txn).await?; + self.create_revoke_record(id, parent_id, reason, &txn) + .await?; //4. mark datakey revoked if pending revoke count >= threshold - let _: ExecResult = txn.execute(Statement::from_sql_and_values( - DatabaseBackend::MySql, - "UPDATE data_key SET key_state = ? \ + let _: ExecResult = txn + .execute(Statement::from_sql_and_values( + DatabaseBackend::MySql, + "UPDATE data_key SET key_state = ? \ WHERE id = ? AND ( \ SELECT COUNT(*) FROM pending_operation WHERE key_id = ?) >= ?", - [KeyState::Revoked.to_string().into(), id.into(), id.into(), threshold.into()], - )).await?; + [ + KeyState::Revoked.to_string().into(), + id.into(), + id.into(), + threshold.into(), + ], + )) + .await?; txn.commit().await?; Ok(()) } @@ -428,16 +704,22 @@ impl<'a> Repository for DataKeyRepository<'a> { async fn cancel_delete_key(&self, user_id: i32, id: i32) -> Result<()> { let txn = self.db_connection.begin().await?; //1. 
delete pending delete record - self.delete_pending_operation( - user_id, id, RequestType::Delete, &txn).await?; + self.delete_pending_operation(user_id, id, RequestType::Delete, &txn) + .await?; //2. update status if there is not any pending delete record. - let _: ExecResult = txn.execute(Statement::from_sql_and_values( - DatabaseBackend::MySql, - "UPDATE data_key SET key_state = ? \ + let _: ExecResult = txn + .execute(Statement::from_sql_and_values( + DatabaseBackend::MySql, + "UPDATE data_key SET key_state = ? \ WHERE id = ? AND ( \ SELECT COUNT(*) FROM pending_operation WHERE key_id = ?) = ?", - [KeyState::Disabled.to_string().into(), id.into(), id.into(), 0i32.into()], - )) + [ + KeyState::Disabled.to_string().into(), + id.into(), + id.into(), + 0i32.into(), + ], + )) .await?; txn.commit().await?; Ok(()) @@ -446,16 +728,23 @@ impl<'a> Repository for DataKeyRepository<'a> { async fn cancel_revoke_key(&self, user_id: i32, id: i32, parent_id: i32) -> Result<()> { let txn = self.db_connection.begin().await?; //1. delete pending delete record - self.delete_pending_operation(user_id, id, RequestType::Revoke, &txn).await?; + self.delete_pending_operation(user_id, id, RequestType::Revoke, &txn) + .await?; //2. delete revoked record self.delete_revoke_record(id, parent_id, &txn).await?; //3. update status if there is not any pending delete record. - let _: ExecResult = txn.execute(Statement::from_sql_and_values( + let _: ExecResult = txn + .execute(Statement::from_sql_and_values( DatabaseBackend::MySql, "UPDATE data_key SET key_state = ? \ WHERE id = ? AND ( \ SELECT COUNT(*) FROM pending_operation WHERE key_id = ?) = ?", - [KeyState::Disabled.to_string().into(), id.into(), id.into(), 0i32.into()], + [ + KeyState::Disabled.to_string().into(), + id.into(), + id.into(), + 0i32.into(), + ], )) .await?; txn.commit().await?; @@ -463,16 +752,13 @@ impl<'a> Repository for DataKeyRepository<'a> { } async fn get_x509_crl_by_ca_id(&self, id: i32) -> Result { - match crl_content_dto::Entity::find().filter( - crl_content_dto::Column::CaId.eq(id) - ).one( - self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } - Some(content) => { - Ok(X509CRL::try_from(content)?) - } + match crl_content_dto::Entity::find() + .filter(crl_content_dto::Column::CaId.eq(id)) + .one(self.db_connection) + .await? 
+ { + None => Err(Error::NotFoundError), + Some(content) => Ok(X509CRL::try_from(content)?), } } @@ -483,16 +769,20 @@ impl<'a> Repository for DataKeyRepository<'a> { ca_id: Set(crl.ca_id), data: Set(encode_u8_to_hex_string(&crl.data)), create_at: Set(crl.create_at), - update_at: Set(crl.update_at) + update_at: Set(crl.update_at), }; match self.get_x509_crl_by_ca_id(ca_id).await { Ok(_) => { //update crl content with new version - crl_content_dto::Entity::update(crl_model.clone()).filter( - crl_content_dto::Column::CaId.eq(ca_id)).exec(self.db_connection).await?; + crl_content_dto::Entity::update(crl_model.clone()) + .filter(crl_content_dto::Column::CaId.eq(ca_id)) + .exec(self.db_connection) + .await?; } Err(_) => { - crl_content_dto::Entity::insert(crl_model).exec(self.db_connection).await?; + crl_content_dto::Entity::insert(crl_model) + .exec(self.db_connection) + .await?; } }; Ok(()) @@ -501,17 +791,19 @@ impl<'a> Repository for DataKeyRepository<'a> { #[cfg(test)] mod tests { - use std::collections::HashMap; - use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction, TransactionTrait}; - use crate::domain::datakey::entity::{DataKey, KeyState, KeyType, ParentKey, RevokedKey, Visibility}; - use super::super::super::x509_revoked_key::dto as revoked_key_dto; - use chrono::{Duration}; use super::super::super::request_delete::dto as request_dto; + use super::super::super::x509_revoked_key::dto as revoked_key_dto; + use crate::domain::datakey::entity::{ + DataKey, KeyState, KeyType, ParentKey, RevokedKey, Visibility, + }; use crate::domain::datakey::repository::Repository; use crate::infra::database::model::datakey::dto; - use crate::util::error::Result; - use crate::infra::database::model::datakey::repository::{DataKeyRepository}; + use crate::infra::database::model::datakey::repository::DataKeyRepository; use crate::infra::database::model::request_delete::dto::RequestType; + use crate::util::error::Result; + use chrono::Duration; + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction, TransactionTrait}; + use std::collections::HashMap; // unmark me when "num_items" issue fixed. 
// #[tokio::test] @@ -649,33 +941,32 @@ mod tests { async fn test_datakey_repository_get_by_id_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - name: "Test Key".to_string(), - description: "".to_string(), - visibility: Visibility::Public.to_string(), - user: 0, - attributes: "{}".to_string(), - key_type: "pgp".to_string(), - parent_id: None, - fingerprint: "".to_string(), - serial_number: None, - private_key: "0708090A".to_string(), - public_key: "040506".to_string(), - certificate: "010203".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - key_state: "disabled".to_string(), - user_email: None, - request_delete_users: None, - request_revoke_users: None, - x509_crl_update_at: None, - }], - ]).into_connection(); + .append_query_results([vec![dto::Model { + id: 1, + name: "Test Key".to_string(), + description: "".to_string(), + visibility: Visibility::Public.to_string(), + user: 0, + attributes: "{}".to_string(), + key_type: "pgp".to_string(), + parent_id: None, + fingerprint: "".to_string(), + serial_number: None, + private_key: "0708090A".to_string(), + public_key: "040506".to_string(), + certificate: "010203".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + key_state: "disabled".to_string(), + user_email: None, + request_delete_users: None, + request_revoke_users: None, + x509_crl_update_at: None, + }]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let user = DataKey{ + let user = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -686,9 +977,9 @@ mod tests { parent_id: None, fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -698,17 +989,24 @@ mod tests { parent_key: None, }; assert_eq!( - datakey_repository.get_by_id_or_name(Some(1), None, false).await?, user + datakey_repository + .get_by_id_or_name(Some(1), None, false) + .await?, + user ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`id` = ? AND `data_key`.`key_state` <> ? 
GROUP BY `data_key`.`id` LIMIT ?"#, - ["delete".into(), "revoke".into(), 1i32.into(), "deleted".into(), 1u64.into()] - ), - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`id` = ? AND `data_key`.`key_state` <> ? GROUP BY `data_key`.`id` LIMIT ?"#, + [ + "delete".into(), + "revoke".into(), + 1i32.into(), + "deleted".into(), + 1u64.into() + ] + ),] ); Ok(()) @@ -719,18 +1017,19 @@ mod tests { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) .append_exec_results([ - MockExecResult{ + MockExecResult { last_insert_id: 1, rows_affected: 1, }, - MockExecResult{ + MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }, + ]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let datakey = DataKey{ + let datakey = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -741,9 +1040,9 @@ mod tests { parent_id: None, fingerprint: "456".to_string(), serial_number: Some("123".to_string()), - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -752,20 +1051,37 @@ mod tests { request_revoke_users: None, parent_key: None, }; - assert_eq!(datakey_repository.update_state(1, KeyState::Enabled).await?,()); - assert_eq!(datakey_repository.update_key_data( datakey).await?,()); + assert_eq!( + datakey_repository + .update_state(1, KeyState::Enabled) + .await?, + () + ); + assert_eq!(datakey_repository.update_key_data(datakey).await?, ()); assert_eq!( db.into_transaction_log(), [ Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"UPDATE `data_key` SET `key_state` = ? WHERE `data_key`.`id` = ? AND `data_key`.`key_state` <> ?"#, - [KeyState::Enabled.to_string().into(), 1i32.into(), "deleted".into()] + [ + KeyState::Enabled.to_string().into(), + 1i32.into(), + "deleted".into() + ] ), Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"UPDATE `data_key` SET `serial_number` = ?, `fingerprint` = ?, `private_key` = ?, `public_key` = ?, `certificate` = ? WHERE `data_key`.`id` = ? 
AND `data_key`.`key_state` <> ?"#, - ["123".into(), "456".into(), "0708090A".into(), "040506".into(), "010203".into(), 1i32.into(), "deleted".into()] + [ + "123".into(), + "456".into(), + "0708090A".into(), + "040506".into(), + "010203".into(), + 1i32.into(), + "deleted".into() + ] ) ] ); @@ -777,33 +1093,32 @@ mod tests { async fn test_datakey_repository_get_keys_for_crl_update_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - name: "Test Key".to_string(), - description: "".to_string(), - visibility: Visibility::Public.to_string(), - user: 0, - attributes: "{}".to_string(), - key_type: "pgp".to_string(), - parent_id: None, - fingerprint: "456".to_string(), - serial_number: Some("123".to_string()), - private_key: "0708090A".to_string(), - public_key: "040506".to_string(), - certificate: "010203".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - key_state: "disabled".to_string(), - user_email: None, - request_delete_users: None, - request_revoke_users: None, - x509_crl_update_at: None, - }], - ]).into_connection(); + .append_query_results([vec![dto::Model { + id: 1, + name: "Test Key".to_string(), + description: "".to_string(), + visibility: Visibility::Public.to_string(), + user: 0, + attributes: "{}".to_string(), + key_type: "pgp".to_string(), + parent_id: None, + fingerprint: "456".to_string(), + serial_number: Some("123".to_string()), + private_key: "0708090A".to_string(), + public_key: "040506".to_string(), + certificate: "010203".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + key_state: "disabled".to_string(), + user_email: None, + request_delete_users: None, + request_revoke_users: None, + x509_crl_update_at: None, + }]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let datakey = DataKey{ + let datakey = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -814,9 +1129,9 @@ mod tests { parent_id: None, fingerprint: "456".to_string(), serial_number: Some("123".to_string()), - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -826,16 +1141,21 @@ mod tests { parent_key: None, }; let duration = Duration::days(1); - assert_eq!(datakey_repository.get_keys_for_crl_update(duration).await?, vec![datakey]); + assert_eq!( + datakey_repository.get_keys_for_crl_update(duration).await?, + vec![datakey] + ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, `crl_table`.`update_at` AS `x509_crl_update_at` FROM `data_key` LEFT JOIN `x509_crl_content` AS `crl_table` ON `crl_table`.`ca_id` = `data_key`.`id` WHERE (`data_key`.`key_type` = ? OR `data_key`.`key_type` = ?) 
AND `data_key`.`key_state` <> ?"#, - [KeyType::X509CA.to_string().into(), KeyType::X509ICA.to_string().into(), "deleted".into()] - ) - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, `crl_table`.`update_at` AS `x509_crl_update_at` FROM `data_key` LEFT JOIN `x509_crl_content` AS `crl_table` ON `crl_table`.`ca_id` = `data_key`.`id` WHERE (`data_key`.`key_type` = ? OR `data_key`.`key_type` = ?) AND `data_key`.`key_state` <> ?"#, + [ + KeyType::X509CA.to_string().into(), + KeyType::X509ICA.to_string().into(), + "deleted".into() + ] + )] ); Ok(()) @@ -845,19 +1165,18 @@ mod tests { async fn test_datakey_repository_get_revoked_serial_number_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![revoked_key_dto::Model { - id: 1, - key_id: 1, - ca_id: 1, - serial_number: Some("123".to_string()), - create_at: now.clone(), - reason: "unspecified".to_string(), - }], - ]).into_connection(); + .append_query_results([vec![revoked_key_dto::Model { + id: 1, + key_id: 1, + ca_id: 1, + serial_number: Some("123".to_string()), + create_at: now.clone(), + reason: "unspecified".to_string(), + }]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let revoked_key = revoked_key_dto::Model{ + let revoked_key = revoked_key_dto::Model { id: 1, key_id: 1, ca_id: 1, @@ -865,17 +1184,19 @@ mod tests { create_at: now.clone(), reason: "unspecified".to_string(), }; - assert_eq!(datakey_repository.get_revoked_serial_number_by_parent_id(1).await?, - vec![RevokedKey::try_from(revoked_key)?]); + assert_eq!( + datakey_repository + .get_revoked_serial_number_by_parent_id(1) + .await?, + vec![RevokedKey::try_from(revoked_key)?] + ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `x509_keys_revoked`.`id`, `x509_keys_revoked`.`key_id`, `x509_keys_revoked`.`ca_id`, `x509_keys_revoked`.`reason`, `x509_keys_revoked`.`create_at`, `datakey_table`.`serial_number` AS `serial_number` FROM `x509_keys_revoked` INNER JOIN `data_key` AS `datakey_table` ON `datakey_table`.`id` = `x509_keys_revoked`.`key_id` AND (`datakey_table`.`key_state` = ? AND `x509_keys_revoked`.`ca_id` = ?)"#, - [KeyState::Revoked.to_string().into(), 1i32.into()] - ) - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `x509_keys_revoked`.`id`, `x509_keys_revoked`.`key_id`, `x509_keys_revoked`.`ca_id`, `x509_keys_revoked`.`reason`, `x509_keys_revoked`.`create_at`, `datakey_table`.`serial_number` AS `serial_number` FROM `x509_keys_revoked` INNER JOIN `data_key` AS `datakey_table` ON `datakey_table`.`id` = `x509_keys_revoked`.`key_id` AND (`datakey_table`.`key_state` = ? 
AND `x509_keys_revoked`.`ca_id` = ?)"#, + [KeyState::Revoked.to_string().into(), 1i32.into()] + )] ); Ok(()) } @@ -884,33 +1205,32 @@ mod tests { async fn test_datakey_get_raw_key_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - name: "Test Key".to_string(), - description: "".to_string(), - visibility: Visibility::Public.to_string(), - user: 0, - attributes: "{}".to_string(), - key_type: "pgp".to_string(), - parent_id: None, - fingerprint: "".to_string(), - serial_number: None, - private_key: "0708090A".to_string(), - public_key: "040506".to_string(), - certificate: "010203".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - key_state: "disabled".to_string(), - user_email: None, - request_delete_users: None, - request_revoke_users: None, - x509_crl_update_at: None, - }], - ]).into_connection(); + .append_query_results([vec![dto::Model { + id: 1, + name: "Test Key".to_string(), + description: "".to_string(), + visibility: Visibility::Public.to_string(), + user: 0, + attributes: "{}".to_string(), + key_type: "pgp".to_string(), + parent_id: None, + fingerprint: "".to_string(), + serial_number: None, + private_key: "0708090A".to_string(), + public_key: "040506".to_string(), + certificate: "010203".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + key_state: "disabled".to_string(), + user_email: None, + request_delete_users: None, + request_revoke_users: None, + x509_crl_update_at: None, + }]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let user = DataKey{ + let user = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -921,9 +1241,9 @@ mod tests { parent_id: None, fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -933,17 +1253,18 @@ mod tests { parent_key: None, }; assert_eq!( - datakey_repository.get_by_id_or_name(None, Some("Test Key".to_string()), true).await?, user + datakey_repository + .get_by_id_or_name(None, Some("Test Key".to_string()), true) + .await?, + user ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state` FROM `data_key` WHERE `data_key`.`name` = ? AND `data_key`.`key_state` <> ? LIMIT ?"#, - ["Test Key".into(), "deleted".into(), 1u64.into()] - ), - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state` FROM `data_key` WHERE `data_key`.`name` = ? 
AND `data_key`.`key_state` <> ? LIMIT ?"#, + ["Test Key".into(), "deleted".into(), 1u64.into()] + ),] ); Ok(()) @@ -953,33 +1274,32 @@ mod tests { async fn test_datakey_repository_get_by_name_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - name: "Test Key".to_string(), - description: "".to_string(), - visibility: Visibility::Public.to_string(), - user: 0, - attributes: "{}".to_string(), - key_type: "pgp".to_string(), - parent_id: None, - fingerprint: "".to_string(), - serial_number: None, - private_key: "0708090A".to_string(), - public_key: "040506".to_string(), - certificate: "010203".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - key_state: "disabled".to_string(), - user_email: None, - request_delete_users: None, - request_revoke_users: None, - x509_crl_update_at: None, - }], - ]).into_connection(); + .append_query_results([vec![dto::Model { + id: 1, + name: "Test Key".to_string(), + description: "".to_string(), + visibility: Visibility::Public.to_string(), + user: 0, + attributes: "{}".to_string(), + key_type: "pgp".to_string(), + parent_id: None, + fingerprint: "".to_string(), + serial_number: None, + private_key: "0708090A".to_string(), + public_key: "040506".to_string(), + certificate: "010203".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + key_state: "disabled".to_string(), + user_email: None, + request_delete_users: None, + request_revoke_users: None, + x509_crl_update_at: None, + }]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let user = DataKey{ + let user = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -990,9 +1310,9 @@ mod tests { parent_id: None, fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -1002,17 +1322,24 @@ mod tests { parent_key: None, }; assert_eq!( - datakey_repository.get_by_id_or_name(None, Some("Test Key".to_string()), false).await?, user + datakey_repository + .get_by_id_or_name(None, Some("Test Key".to_string()), false) + .await?, + user ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`name` = ? AND `data_key`.`key_state` <> ? 
GROUP BY `data_key`.`id` LIMIT ?"#, - ["delete".into(), "revoke".into(), "Test Key".into(), "deleted".into(), 1u64.into()] - ), - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`name` = ? AND `data_key`.`key_state` <> ? GROUP BY `data_key`.`id` LIMIT ?"#, + [ + "delete".into(), + "revoke".into(), + "Test Key".into(), + "deleted".into(), + 1u64.into() + ] + ),] ); Ok(()) @@ -1021,26 +1348,21 @@ mod tests { #[tokio::test] async fn test_datakey_repository_delete_datakey_sql_statement() -> Result<()> { let db = MockDatabase::new(DatabaseBackend::MySql) - .append_exec_results([ - MockExecResult { - last_insert_id: 0, - rows_affected: 0, - } - ]).into_connection(); + .append_exec_results([MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - assert_eq!( - datakey_repository.delete(1).await?, () - ); + assert_eq!(datakey_repository.delete(1).await?, ()); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"DELETE FROM `data_key` WHERE `data_key`.`id` = ?"#, - [1i32.into()] - ), - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"DELETE FROM `data_key` WHERE `data_key`.`id` = ?"#, + [1i32.into()] + ),] ); Ok(()) @@ -1095,10 +1417,11 @@ mod tests { request_revoke_users: None, x509_crl_update_at: None, }], - ]).into_connection(); + ]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let user = DataKey{ + let user = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -1109,25 +1432,31 @@ mod tests { parent_id: Some(2), fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, user_email: None, request_delete_users: None, request_revoke_users: None, - parent_key: Some(ParentKey{ + parent_key: Some(ParentKey { name: "Parent Key".to_string(), attributes: HashMap::new(), - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], }), }; assert_eq!( - datakey_repository.get_enabled_key_by_type_and_name_with_parent_key("openpgp".to_string(), "fake_name".to_string()).await?, user + datakey_repository + 
.get_enabled_key_by_type_and_name_with_parent_key( + "openpgp".to_string(), + "fake_name".to_string() + ) + .await?, + user ); assert_eq!( db.into_transaction_log(), @@ -1135,7 +1464,12 @@ mod tests { Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state` FROM `data_key` WHERE `data_key`.`name` = ? AND `data_key`.`key_type` = ? AND `data_key`.`key_state` = ? LIMIT ?"#, - ["fake_name".into(), "openpgp".into(), "enabled".into(), 1u64.into()] + [ + "fake_name".into(), + "openpgp".into(), + "enabled".into(), + 1u64.into() + ] ), Transaction::from_sql_and_values( DatabaseBackend::MySql, @@ -1197,15 +1531,15 @@ mod tests { request_revoke_users: None, x509_crl_update_at: None, }], - ]).append_exec_results( - [ MockExecResult { + ]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - }] - ).into_connection(); + }]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let datakey = DataKey{ + let datakey = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -1216,33 +1550,48 @@ mod tests { parent_id: Some(2), fingerprint: "".to_string(), serial_number: Some("123".to_string()), - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, user_email: None, request_delete_users: None, request_revoke_users: None, - parent_key: Some(ParentKey{ + parent_key: Some(ParentKey { name: "Test Parent Key".to_string(), - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], attributes: HashMap::new(), }), }; - assert_eq!( - datakey_repository.create(datakey.clone()).await?, datakey - ); + assert_eq!(datakey_repository.create(datakey.clone()).await?, datakey); assert_eq!( db.into_transaction_log(), [ Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"INSERT INTO `data_key` (`id`, `name`, `description`, `visibility`, `user`, `attributes`, `key_type`, `parent_id`, `fingerprint`, `serial_number`, `private_key`, `public_key`, `certificate`, `create_at`, `expire_at`, `key_state`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"#, - [1i32.into(), "Test Key".into(), "".into(), "public".into(), 0i32.into(), "{}".into(), "pgp".into(), 2i32.into(), "".into(), "123".into(), "0708090A".into(), "040506".into(), "010203".into(), now.clone().into(), now.clone().into(), "disabled".into()] + [ + 1i32.into(), + "Test Key".into(), + "".into(), + "public".into(), + 0i32.into(), + "{}".into(), + "pgp".into(), + 2i32.into(), + "".into(), + "123".into(), + "0708090A".into(), + "040506".into(), + "010203".into(), + now.clone().into(), + now.clone().into(), + "disabled".into() + ] ), Transaction::from_sql_and_values( DatabaseBackend::MySql, @@ -1263,8 +1612,8 @@ mod tests { async fn test_datakey_repository_get_keys_by_parent_id_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - 
.append_query_results([ - vec![dto::Model { + .append_query_results([vec![ + dto::Model { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -1285,7 +1634,8 @@ mod tests { request_delete_users: None, request_revoke_users: None, x509_crl_update_at: None, - }, dto::Model { + }, + dto::Model { id: 2, name: "Test Key2".to_string(), description: "".to_string(), @@ -1306,11 +1656,12 @@ mod tests { request_delete_users: None, request_revoke_users: None, x509_crl_update_at: None, - }], - ]).into_connection(); + }, + ]]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); - let datakey1 = DataKey{ + let datakey1 = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -1321,9 +1672,9 @@ mod tests { parent_id: None, fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -1332,7 +1683,7 @@ mod tests { request_revoke_users: None, parent_key: None, }; - let datakey2 = DataKey{ + let datakey2 = DataKey { id: 2, name: "Test Key2".to_string(), description: "".to_string(), @@ -1343,9 +1694,9 @@ mod tests { parent_id: None, fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: now.clone(), expire_at: now.clone(), key_state: KeyState::Disabled, @@ -1355,17 +1706,21 @@ mod tests { parent_key: None, }; assert_eq!( - datakey_repository.get_by_parent_id(1).await?, vec![datakey1, datakey2] + datakey_repository.get_by_parent_id(1).await?, + vec![datakey1, datakey2] ); assert_eq!( db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::MySql, - r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`parent_id` = ? AND `data_key`.`key_state` <> ? 
GROUP BY `data_key`.`id`"#, - ["delete".into(), "revoke".into(), 1i32.into(), "deleted".into()] - ), - ] + [Transaction::from_sql_and_values( + DatabaseBackend::MySql, + r#"SELECT `data_key`.`id`, `data_key`.`name`, `data_key`.`description`, `data_key`.`visibility`, `data_key`.`user`, `data_key`.`attributes`, `data_key`.`key_type`, `data_key`.`parent_id`, `data_key`.`fingerprint`, `data_key`.`serial_number`, `data_key`.`private_key`, `data_key`.`public_key`, `data_key`.`certificate`, `data_key`.`create_at`, `data_key`.`expire_at`, `data_key`.`key_state`, user_table.email as user_email, GROUP_CONCAT(request_delete_table.user_email) as request_delete_users, GROUP_CONCAT(request_revoke_table.user_email) as request_revoke_users FROM `data_key` INNER JOIN `user` AS `user_table` ON `user_table`.`id` = `data_key`.`user` LEFT JOIN `pending_operation` AS `request_delete_table` ON `request_delete_table`.`key_id` = `data_key`.`id` AND `request_delete_table`.`request_type` = ? LEFT JOIN `pending_operation` AS `request_revoke_table` ON `request_revoke_table`.`key_id` = `data_key`.`id` AND `request_revoke_table`.`request_type` = ? WHERE `data_key`.`parent_id` = ? AND `data_key`.`key_state` <> ? GROUP BY `data_key`.`id`"#, + [ + "delete".into(), + "revoke".into(), + 1i32.into(), + "deleted".into() + ] + ),] ); Ok(()) } @@ -1374,24 +1729,30 @@ mod tests { async fn test_datakey_repository_create_delete_pending_operation_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_exec_results([ - MockExecResult{ - last_insert_id: 1, - rows_affected: 1, - } - ]).into_connection(); + .append_exec_results([MockExecResult { + last_insert_id: 1, + rows_affected: 1, + }]) + .into_connection(); let datakey_repository = DataKeyRepository::new(&db); let mut tx = db.begin().await?; assert_eq!( - datakey_repository.create_pending_operation(request_dto::Model { - id: 0, - user_id: 1, - key_id: 1, - request_type: RequestType::Delete.to_string(), - user_email: "fake_email".to_string(), - create_at: now, - }, &mut tx).await?, ()); + datakey_repository + .create_pending_operation( + request_dto::Model { + id: 0, + user_id: 1, + key_id: 1, + request_type: RequestType::Delete.to_string(), + user_email: "fake_email".to_string(), + create_at: now, + }, + &mut tx + ) + .await?, + () + ); tx.commit().await?; //TODO 1.Now mock database begin statement is configured with postgres backend, enabled this when fixed in upstream // assert_eq!( @@ -1418,6 +1779,4 @@ mod tests { // ); Ok(()) } - } - diff --git a/src/infra/database/model/mod.rs b/src/infra/database/model/mod.rs index 4953ee7..94c97e6 100644 --- a/src/infra/database/model/mod.rs +++ b/src/infra/database/model/mod.rs @@ -1,7 +1,7 @@ pub mod clusterkey; pub mod datakey; -pub mod user; -pub mod token; pub mod request_delete; +pub mod token; +pub mod user; +pub mod x509_crl_content; pub mod x509_revoked_key; -pub mod x509_crl_content; \ No newline at end of file diff --git a/src/infra/database/model/request_delete/dto.rs b/src/infra/database/model/request_delete/dto.rs index 80426ff..506a4c1 100644 --- a/src/infra/database/model/request_delete/dto.rs +++ b/src/infra/database/model/request_delete/dto.rs @@ -13,14 +13,14 @@ * * // See the Mulan PSL v2 for more details. 
* */ +use crate::util::error::Error; +use chrono::{DateTime, Utc}; use std::fmt::{Display, Formatter}; use std::str::FromStr; -use chrono::{DateTime, Utc}; -use crate::util::error::Error; -use sqlx::types::chrono; use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; +use sqlx::types::chrono; #[derive(Debug, Clone, PartialEq, sqlx::Type)] pub enum RequestType { @@ -46,7 +46,7 @@ impl FromStr for RequestType { match s { "delete" => Ok(RequestType::Delete), "revoke" => Ok(RequestType::Revoke), - _ => Err(Error::UnsupportedTypeError(s.to_string())) + _ => Err(Error::UnsupportedTypeError(s.to_string())), } } } @@ -68,7 +68,6 @@ pub enum Relation {} impl ActiveModelBehavior for ActiveModel {} - impl Model { pub fn new_for_delete(key_id: i32, user_id: i32, user_email: String) -> Self { Self { diff --git a/src/infra/database/model/token/dto.rs b/src/infra/database/model/token/dto.rs index 0b3f5b0..2a0aa5f 100644 --- a/src/infra/database/model/token/dto.rs +++ b/src/infra/database/model/token/dto.rs @@ -44,7 +44,7 @@ impl From for Token { description: dto.description.clone(), token: dto.token.clone(), create_at: dto.create_at, - expire_at:dto.expire_at, + expire_at: dto.expire_at, } } } @@ -62,7 +62,7 @@ mod tests { description: "Test token".to_string(), token: "hashedtoken".to_string(), create_at: now, - expire_at: now + chrono::Duration::days(1) + expire_at: now + chrono::Duration::days(1), }; let token = Token::from(dto.clone()); assert_eq!(token.id, dto.id); @@ -73,4 +73,3 @@ mod tests { assert_eq!(token.expire_at, dto.expire_at); } } - diff --git a/src/infra/database/model/token/repository.rs b/src/infra/database/model/token/repository.rs index aaa5efc..48eca05 100644 --- a/src/infra/database/model/token/repository.rs +++ b/src/infra/database/model/token/repository.rs @@ -14,27 +14,27 @@ * */ -use crate::domain::token::entity::{Token}; +use crate::domain::token::entity::Token; use crate::domain::token::repository::Repository; -use crate::util::error::Result; -use async_trait::async_trait; -use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, Condition, ActiveModelTrait}; use crate::infra::database::model::token; use crate::infra::database::model::token::dto::Entity as TokenDTO; use crate::util::error; +use crate::util::error::Result; use crate::util::key::get_token_hash; - +use async_trait::async_trait; +use sea_orm::{ + ActiveModelTrait, ActiveValue::Set, ColumnTrait, Condition, DatabaseConnection, EntityTrait, + QueryFilter, +}; #[derive(Clone)] pub struct TokenRepository<'a> { - db_connection: &'a DatabaseConnection + db_connection: &'a DatabaseConnection, } impl<'a> TokenRepository<'a> { pub fn new(db_connection: &'a DatabaseConnection) -> Self { - Self { - db_connection, - } + Self { db_connection } } } @@ -45,7 +45,7 @@ impl<'a> Repository for TokenRepository<'a> { user_id: Set(token.user_id), description: Set(token.description), token: Set(get_token_hash(&token.token)), - create_at:Set(token.create_at), + create_at: Set(token.create_at), expire_at: Set(token.expire_at), ..Default::default() }; @@ -54,39 +54,39 @@ impl<'a> Repository for TokenRepository<'a> { async fn get_token_by_id(&self, id: i32) -> Result { match TokenDTO::find_by_id(id).one(self.db_connection).await? 
{ - None => { - Err(error::Error::NotFoundError) - } - Some(token) => { - Ok(Token::from(token)) - } + None => Err(error::Error::NotFoundError), + Some(token) => Ok(Token::from(token)), } } async fn get_token_by_value(&self, token: &str) -> Result { - match TokenDTO::find().filter( - token::dto::Column::Token.eq(get_token_hash(token))).one( - self.db_connection).await? { - None => { - Err(error::Error::NotFoundError) - } - Some(token) => { - Ok(Token::from(token)) - } + match TokenDTO::find() + .filter(token::dto::Column::Token.eq(get_token_hash(token))) + .one(self.db_connection) + .await? + { + None => Err(error::Error::NotFoundError), + Some(token) => Ok(Token::from(token)), } } async fn delete_by_user_and_id(&self, id: i32, user_id: i32) -> Result<()> { - let _ = TokenDTO::delete_many().filter(Condition::all() - .add(token::dto::Column::Id.eq(id)) - .add(token::dto::Column::UserId.eq(user_id))).exec(self.db_connection) + let _ = TokenDTO::delete_many() + .filter( + Condition::all() + .add(token::dto::Column::Id.eq(id)) + .add(token::dto::Column::UserId.eq(user_id)), + ) + .exec(self.db_connection) .await?; Ok(()) } async fn get_token_by_user_id(&self, id: i32) -> Result> { - let tokens = TokenDTO::find().filter( - token::dto::Column::UserId.eq(id)).all(self.db_connection).await?; + let tokens = TokenDTO::find() + .filter(token::dto::Column::UserId.eq(id)) + .all(self.db_connection) + .await?; let mut results = vec![]; for dto in tokens.into_iter() { results.push(Token::from(dto)); @@ -97,36 +97,34 @@ impl<'a> Repository for TokenRepository<'a> { #[cfg(test)] mod tests { - use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; use crate::domain::token::entity::Token; use crate::domain::token::repository::Repository; use crate::infra::database::model::token::dto; + use crate::infra::database::model::token::repository::TokenRepository; use crate::util::error::Result; - use crate::infra::database::model::token::repository::{TokenRepository}; use crate::util::key::get_token_hash; + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; #[tokio::test] async fn test_token_repository_create_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - user_id: 0, - description: "fake_token".to_string(), - token: "random_number".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - }], - ]).append_exec_results([ - MockExecResult{ + .append_query_results([vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }]]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let token_repository = TokenRepository::new(&db); - let user = Token{ + let user = Token { id: 1, user_id: 0, description: "fake_token".to_string(), @@ -153,7 +151,13 @@ mod tests { Transaction::from_sql_and_values( DatabaseBackend::MySql, r#"INSERT INTO `token` (`user_id`, `description`, `token`, `create_at`, `expire_at`) VALUES (?, ?, ?, ?, ?)"#, - [0i32.into(), "fake_token".into(), hashed_token.into(), now.clone().into(), now.clone().into()] + [ + 0i32.into(), + "fake_token".into(), + hashed_token.into(), + now.clone().into(), + now.clone().into() + ] ), Transaction::from_sql_and_values( DatabaseBackend::MySql, @@ -170,21 +174,19 @@ mod tests { async fn 
test_token_repository_delete_sql_statement() -> Result<()> { let now = chrono::Utc::now(); let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - user_id: 0, - description: "fake_token".to_string(), - token: "random_number".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - }], - ]).append_exec_results([ - MockExecResult{ + .append_query_results([vec![dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }]]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let token_repository = TokenRepository::new(&db); assert_eq!(token_repository.delete_by_user_and_id(1, 1).await?, ()); @@ -224,22 +226,26 @@ mod tests { create_at: now.clone(), expire_at: now.clone(), }], - vec![dto::Model { - id: 1, - user_id: 0, - description: "fake_token".to_string(), - token: "random_number".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - }, dto::Model { - id: 2, - user_id: 0, - description: "fake_token2".to_string(), - token: "random_number2".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - }], - ]).into_connection(); + vec![ + dto::Model { + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }, + dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }, + ], + ]) + .into_connection(); let token_repository = TokenRepository::new(&db); assert_eq!( @@ -268,22 +274,23 @@ mod tests { assert_eq!( token_repository.get_token_by_user_id(0).await?, vec![ - Token::from(dto::Model { - id: 1, - user_id: 0, - description: "fake_token".to_string(), - token: "random_number".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - }), Token::from(dto::Model { - id: 2, - user_id: 0, - description: "fake_token2".to_string(), - token: "random_number2".to_string(), - create_at: now.clone(), - expire_at: now.clone(), - })] + id: 1, + user_id: 0, + description: "fake_token".to_string(), + token: "random_number".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }), + Token::from(dto::Model { + id: 2, + user_id: 0, + description: "fake_token2".to_string(), + token: "random_number2".to_string(), + create_at: now.clone(), + expire_at: now.clone(), + }) + ] ); let hashed_token = get_token_hash("fake_content"); diff --git a/src/infra/database/model/user/dto.rs b/src/infra/database/model/user/dto.rs index a80cf18..b3f1ce3 100644 --- a/src/infra/database/model/user/dto.rs +++ b/src/infra/database/model/user/dto.rs @@ -22,15 +22,15 @@ use serde::{Deserialize, Serialize}; pub struct Model { #[sea_orm(primary_key)] pub id: i32, - pub email: String + pub email: String, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { #[sea_orm( - belongs_to = "super::super::datakey::dto::Entity", - from = "Column::Id", - to = "super::super::datakey::dto::Column::User" + belongs_to = "super::super::datakey::dto::Entity", + from = "Column::Id", + to = "super::super::datakey::dto::Column::User" )] Datakey, } @@ -48,7 +48,7 @@ impl From for User { fn from(dto: Model) -> Self { Self { id: dto.id, - email: dto.email + email: dto.email, } } } diff --git a/src/infra/database/model/user/repository.rs 
b/src/infra/database/model/user/repository.rs index f6ecaf8..55ba6bd 100644 --- a/src/infra/database/model/user/repository.rs +++ b/src/infra/database/model/user/repository.rs @@ -15,34 +15,31 @@ */ use super::dto::Entity as UserDTO; -use crate::infra::database::model::user; -use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, ActiveValue::Set, ActiveModelTrait}; use crate::domain::user::entity::User; use crate::domain::user::repository::Repository; +use crate::infra::database::model::user; use crate::util::error::{Error, Result}; use async_trait::async_trait; +use sea_orm::{ + ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, +}; #[derive(Clone)] pub struct UserRepository<'a> { - db_connection: &'a DatabaseConnection + db_connection: &'a DatabaseConnection, } impl<'a> UserRepository<'a> { pub fn new(db_connection: &'a DatabaseConnection) -> Self { - Self { - db_connection, - } + Self { db_connection } } } #[async_trait] impl<'a> Repository for UserRepository<'a> { - async fn create(&self, user: User) -> Result { return match self.get_by_email(&user.email).await { - Ok(existed) => { - Ok(existed) - } + Ok(existed) => Ok(existed), Err(_err) => { let user = user::dto::ActiveModel { email: Set(user.email), @@ -50,49 +47,41 @@ impl<'a> Repository for UserRepository<'a> { }; Ok(User::from(user.insert(self.db_connection).await?)) } - } + }; } async fn get_by_id(&self, id: i32) -> Result { - match UserDTO::find_by_id(id).one( - self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } - Some(user) => { - Ok(User::from(user)) - } + match UserDTO::find_by_id(id).one(self.db_connection).await? { + None => Err(Error::NotFoundError), + Some(user) => Ok(User::from(user)), } } async fn get_by_email(&self, email: &str) -> Result { - match UserDTO::find().filter( - user::dto::Column::Email.eq(email)).one( - self.db_connection).await? { - None => { - Err(Error::NotFoundError) - } - Some(user) => { - Ok(User::from(user)) - } + match UserDTO::find() + .filter(user::dto::Column::Email.eq(email)) + .one(self.db_connection) + .await? 
+ { + None => Err(Error::NotFoundError), + Some(user) => Ok(User::from(user)), } } async fn delete_by_id(&self, id: i32) -> Result<()> { - let _ = UserDTO::delete_by_id(id).exec(self.db_connection) - .await?; + let _ = UserDTO::delete_by_id(id).exec(self.db_connection).await?; Ok(()) } } #[cfg(test)] mod tests { - use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; use crate::domain::user::entity::User; use crate::domain::user::repository::Repository; use crate::infra::database::model::user::dto; - use crate::util::error::Result; use crate::infra::database::model::user::repository::UserRepository; + use crate::util::error::Result; + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; #[tokio::test] async fn test_user_repository_query_sql_statement() -> Result<()> { @@ -106,7 +95,8 @@ mod tests { id: 2, email: "fake_email".to_string(), }], - ]).into_connection(); + ]) + .into_connection(); let user_repository = UserRepository::new(&db); assert_eq!( @@ -155,15 +145,15 @@ mod tests { id: 3, email: "fake_email".to_string(), }], - ]).append_exec_results([ - MockExecResult{ + ]) + .append_exec_results([MockExecResult { last_insert_id: 3, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let user_repository = UserRepository::new(&db); - let user = User{ + let user = User { id: 0, email: "fake_string".to_string(), }; @@ -201,17 +191,15 @@ mod tests { #[tokio::test] async fn test_user_repository_delete_sql_statement() -> Result<()> { let db = MockDatabase::new(DatabaseBackend::MySql) - .append_query_results([ - vec![dto::Model { - id: 1, - email: "fake_email".to_string(), - }], - ]).append_exec_results([ - MockExecResult{ + .append_query_results([vec![dto::Model { + id: 1, + email: "fake_email".to_string(), + }]]) + .append_exec_results([MockExecResult { last_insert_id: 1, rows_affected: 1, - } - ]).into_connection(); + }]) + .into_connection(); let user_repository = UserRepository::new(&db); assert_eq!(user_repository.delete_by_id(1).await?, ()); diff --git a/src/infra/database/model/x509_crl_content/dto.rs b/src/infra/database/model/x509_crl_content/dto.rs index 45e4e33..9dfdd50 100644 --- a/src/infra/database/model/x509_crl_content/dto.rs +++ b/src/infra/database/model/x509_crl_content/dto.rs @@ -13,15 +13,14 @@ * * // See the Mulan PSL v2 for more details. 
* */ -use chrono::{DateTime, Utc}; -use crate::domain::datakey::entity::{X509CRL}; +use crate::domain::datakey::entity::X509CRL; use crate::util::error::Error; +use chrono::{DateTime, Utc}; -use sqlx::types::chrono; +use crate::util::key::{decode_hex_string_to_u8, encode_u8_to_hex_string}; use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; -use crate::util::key::{decode_hex_string_to_u8, encode_u8_to_hex_string}; - +use sqlx::types::chrono; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] #[sea_orm(table_name = "x509_crl_content")] @@ -62,13 +61,12 @@ impl TryFrom for X509CRL { } } - #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { #[sea_orm( - belongs_to = "super::super::datakey::dto::Entity", - from = "Column::CaId", - to = "super::super::datakey::dto::Column::Id" + belongs_to = "super::super::datakey::dto::Entity", + from = "Column::CaId", + to = "super::super::datakey::dto::Column::Id" )] Datakey, } @@ -79,4 +77,3 @@ impl Related for Entity { } } impl ActiveModelBehavior for ActiveModel {} - diff --git a/src/infra/database/model/x509_revoked_key/dto.rs b/src/infra/database/model/x509_revoked_key/dto.rs index dff5d85..dff183f 100644 --- a/src/infra/database/model/x509_revoked_key/dto.rs +++ b/src/infra/database/model/x509_revoked_key/dto.rs @@ -13,13 +13,12 @@ * * // See the Mulan PSL v2 for more details. * */ -use std::str::FromStr; -use chrono::{DateTime, Utc}; use crate::domain::datakey::entity::{RevokedKey, X509RevokeReason}; use crate::util::error::Error; +use chrono::{DateTime, Utc}; use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; - +use std::str::FromStr; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] #[sea_orm(table_name = "x509_keys_revoked")] @@ -59,7 +58,7 @@ mod tests { #[test] fn test_revoked_key_dto_conversion() { let now = Utc::now(); - let dto = Model{ + let dto = Model { id: 0, key_id: 1, ca_id: 2, diff --git a/src/infra/database/pool.rs b/src/infra/database/pool.rs index f63d770..48979fc 100644 --- a/src/infra/database/pool.rs +++ b/src/infra/database/pool.rs @@ -16,11 +16,11 @@ use config::Value; use once_cell::sync::OnceCell; -use sqlx::mysql::{MySql}; +use sea_orm::{ConnectOptions, Database, DatabaseConnection}; +use sqlx::mysql::MySql; use sqlx::pool::Pool; use std::collections::HashMap; use std::time::Duration; -use sea_orm::{DatabaseConnection, ConnectOptions, Database}; use crate::util::error::{Error, Result}; pub type DbPool = Pool; @@ -61,8 +61,13 @@ pub async fn create_pool(config: &HashMap) -> Result<()> { .sqlx_logging(true) .sqlx_logging_level(log::LevelFilter::Info); - DB_CONNECTION.set(Database::connect(opt).await?).expect("database connection configured"); - get_db_connection()?.ping().await.expect("database connection failed"); + DB_CONNECTION + .set(Database::connect(opt).await?) + .expect("database connection configured"); + get_db_connection()? 
+ .ping() + .await + .expect("database connection failed"); Ok(()) } pub fn get_db_connection() -> Result<&'static DatabaseConnection> { @@ -76,8 +81,8 @@ pub fn get_db_connection() -> Result<&'static DatabaseConnection> { #[cfg(test)] mod tests { - use testcontainers::clients; use crate::util::error::Result; + use testcontainers::clients; use testcontainers::core::WaitFor; use testcontainers::images::generic::GenericImage; @@ -93,13 +98,19 @@ mod tests { let database = docker.run(image.clone()); let sqlx_image = GenericImage::new("tommylike/sqlx-cli", "0.7.1.1") - .with_env_var("DATABASE_HOST", database.get_bridge_ip_address().to_string()) + .with_env_var( + "DATABASE_HOST", + database.get_bridge_ip_address().to_string(), + ) .with_env_var("DATABASE_PORT", "3306") .with_env_var("DATABASE_USER", "test") .with_env_var("DATABASE_PASSWORD", "test") .with_env_var("DATABASE_NAME", "signatrust") - .with_volume("./migrations/", "/app/migrations/").with_entrypoint("/app/run_migrations.sh") - .with_wait_for(WaitFor::message_on_stdout("Applied 20230727020628/migrate extend-datakey-name")); + .with_volume("./migrations/", "/app/migrations/") + .with_entrypoint("/app/run_migrations.sh") + .with_wait_for(WaitFor::message_on_stdout( + "Applied 20230727020628/migrate extend-datakey-name", + )); let _migration = docker.run(sqlx_image.clone()); Ok(()) } diff --git a/src/infra/encryption/algorithm/aes.rs b/src/infra/encryption/algorithm/aes.rs index b5abefa..1b137b4 100644 --- a/src/infra/encryption/algorithm/aes.rs +++ b/src/infra/encryption/algorithm/aes.rs @@ -16,13 +16,13 @@ use crate::domain::encryptor::{Algorithm, Encryptor}; use crate::util::error::Error; +use crate::util::error::Result; use aes_gcm_siv::{ aead::{Aead, KeyInit, OsRng}, Aes256GcmSiv, }; use generic_array::GenericArray; use rand::{thread_rng, Rng}; -use crate::util::error::Result; pub const NONCE_LENGTH: usize = 12; pub const KEY_LENGTH: usize = 32; @@ -47,7 +47,7 @@ impl Encryptor for Aes256GcmEncryptor { fn encrypt(&self, key: Vec, content: Vec) -> Result> { if key.len() != KEY_LENGTH { - return Err(Error::EncodeError("key size not matched".to_string())) + return Err(Error::EncodeError("key size not matched".to_string())); } let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key)); let random = self.generate_nonce_bytes(); @@ -68,7 +68,7 @@ impl Encryptor for Aes256GcmEncryptor { )); } if key.len() != KEY_LENGTH { - return Err(Error::EncodeError("key size not matched".to_string())) + return Err(Error::EncodeError("key size not matched".to_string())); } let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&key)); let nonce = GenericArray::from_slice(&content[..NONCE_LENGTH]); @@ -108,14 +108,22 @@ mod test { let aes = Aes256GcmEncryptor::default(); let key1 = aes.generate_key(); let content = "fake_content".as_bytes(); - let encoded_1 = aes.encrypt(key1.clone(), content.to_vec()).expect("encode should be successful"); - let encoded_2 = aes.encrypt(key1.clone(), content.to_vec()).expect("encode should be successful"); + let encoded_1 = aes + .encrypt(key1.clone(), content.to_vec()) + .expect("encode should be successful"); + let encoded_2 = aes + .encrypt(key1.clone(), content.to_vec()) + .expect("encode should be successful"); assert_ne!(encoded_1, encoded_2); assert_eq!(encoded_1.len(), encoded_1.len()); assert_ne!(encoded_1, content); assert_ne!(encoded_2, content); - let decode_1 = aes.decrypt(key1.clone(), encoded_1).expect("decode should be successful"); - let decode_2 = aes.decrypt(key1.clone(), 
encoded_2).expect("decode should be successful"); + let decode_1 = aes + .decrypt(key1.clone(), encoded_1) + .expect("decode should be successful"); + let decode_2 = aes + .decrypt(key1.clone(), encoded_2) + .expect("decode should be successful"); assert_eq!(content, decode_1); assert_eq!(content, decode_2); } @@ -126,14 +134,22 @@ mod test { let key1 = aes.generate_key(); let key2 = aes.generate_key(); let content = "fake_content".as_bytes(); - let encoded_1 = aes.encrypt(key1.clone(), content.to_vec()).expect("encode should be successful"); - let encoded_2 = aes.encrypt(key2.clone(), content.to_vec()).expect("encode should be successful"); + let encoded_1 = aes + .encrypt(key1.clone(), content.to_vec()) + .expect("encode should be successful"); + let encoded_2 = aes + .encrypt(key2.clone(), content.to_vec()) + .expect("encode should be successful"); assert_ne!(encoded_1, encoded_2); assert_eq!(encoded_1.len(), encoded_1.len()); assert_ne!(encoded_1, content); assert_ne!(encoded_2, content); - let decode_1 = aes.decrypt(key1.clone(), encoded_1).expect("decode should be successful"); - let decode_2 = aes.decrypt(key2.clone(), encoded_2).expect("decode should be successful"); + let decode_1 = aes + .decrypt(key1.clone(), encoded_1) + .expect("decode should be successful"); + let decode_2 = aes + .decrypt(key2.clone(), encoded_2) + .expect("decode should be successful"); assert_eq!(content, decode_1); assert_eq!(content, decode_2); } @@ -144,13 +160,21 @@ mod test { let key1 = aes.generate_key(); let content_1 = "fake_content1".as_bytes(); let content_2 = "fake_content 2".as_bytes(); - let encoded_1 = aes.encrypt(key1.clone(), content_1.to_vec()).expect("encode should be successful"); - let encoded_2 = aes.encrypt(key1.clone(), content_2.to_vec()).expect("encode should be successful"); + let encoded_1 = aes + .encrypt(key1.clone(), content_1.to_vec()) + .expect("encode should be successful"); + let encoded_2 = aes + .encrypt(key1.clone(), content_2.to_vec()) + .expect("encode should be successful"); assert_ne!(encoded_1, encoded_2); assert_ne!(encoded_1, content_1); assert_ne!(encoded_2, content_2); - let decode_1 = aes.decrypt(key1.clone(), encoded_1).expect("decode should be successful"); - let decode_2 = aes.decrypt(key1.clone(), encoded_2).expect("decode should be successful"); + let decode_1 = aes + .decrypt(key1.clone(), encoded_1) + .expect("decode should be successful"); + let decode_2 = aes + .decrypt(key1.clone(), encoded_2) + .expect("decode should be successful"); assert_eq!(content_1, decode_1); assert_eq!(content_2, decode_2); } @@ -161,16 +185,26 @@ mod test { let key1 = aes.generate_key(); let encoded = "123456789abc".as_bytes(); let invalid = "invalid_encoded_content_although_long_enough".as_bytes(); - let _ = aes.decrypt(key1.clone(), vec![]).expect_err("decode should fail due to content not long enough"); - let _ = aes.decrypt(key1.clone(), encoded.to_vec()).expect_err("decode should fail due to content not long enough"); - let _ = aes.decrypt(key1.clone(), invalid.to_vec()).expect_err("decode should fail due to content invalid"); + let _ = aes + .decrypt(key1.clone(), vec![]) + .expect_err("decode should fail due to content not long enough"); + let _ = aes + .decrypt(key1.clone(), encoded.to_vec()) + .expect_err("decode should fail due to content not long enough"); + let _ = aes + .decrypt(key1.clone(), invalid.to_vec()) + .expect_err("decode should fail due to content invalid"); } #[test] fn test_encrypt_decrypt_with_invalid_key_size() { let aes = 
Aes256GcmEncryptor::default(); let invalid_key = "invalid_key".as_bytes(); - let _ = aes.encrypt(invalid_key.to_vec(), vec![]).expect_err("encode should fail due to key size invalid"); - let _ = aes.decrypt(invalid_key.to_vec(), vec![]).expect_err("decode should fail due to key size invalid"); + let _ = aes + .encrypt(invalid_key.to_vec(), vec![]) + .expect_err("encode should fail due to key size invalid"); + let _ = aes + .decrypt(invalid_key.to_vec(), vec![]) + .expect_err("decode should fail due to key size invalid"); } -} \ No newline at end of file +} diff --git a/src/infra/encryption/algorithm/factory.rs b/src/infra/encryption/algorithm/factory.rs index ce7178a..2b93e33 100644 --- a/src/infra/encryption/algorithm/factory.rs +++ b/src/infra/encryption/algorithm/factory.rs @@ -14,12 +14,11 @@ * */ -use crate::infra::encryption::algorithm::aes::Aes256GcmEncryptor; use crate::domain::encryptor::{Algorithm, Encryptor}; -use crate::util::error::{Result}; +use crate::infra::encryption::algorithm::aes::Aes256GcmEncryptor; +use crate::util::error::Result; use std::str::FromStr; - pub struct AlgorithmFactory {} impl AlgorithmFactory { @@ -39,7 +38,8 @@ mod test { #[test] fn test_algorithm_factory() { assert!(AlgorithmFactory::new_algorithm("aes-512-gcm").is_err()); - let algo = AlgorithmFactory::new_algorithm("aes256gsm").expect("algorithm from valid string should succeed"); + let algo = AlgorithmFactory::new_algorithm("aes256gsm") + .expect("algorithm from valid string should succeed"); assert_eq!(algo.algorithm(), Algorithm::Aes256GSM); } } diff --git a/src/infra/encryption/dummy_engine.rs b/src/infra/encryption/dummy_engine.rs index 4551de5..29b07d0 100644 --- a/src/infra/encryption/dummy_engine.rs +++ b/src/infra/encryption/dummy_engine.rs @@ -16,7 +16,6 @@ use crate::domain::encryption_engine::EncryptionEngine; use crate::util::error::Result; use async_trait::async_trait; - #[derive(Default)] pub struct DummyEngine {} @@ -41,4 +40,4 @@ impl EncryptionEngine for DummyEngine { warn!("dummy engine used for encryption, please don't use it in production environment"); Ok(content) } -} \ No newline at end of file +} diff --git a/src/infra/encryption/engine.rs b/src/infra/encryption/engine.rs index ccb2b02..96e8b40 100644 --- a/src/infra/encryption/engine.rs +++ b/src/infra/encryption/engine.rs @@ -14,17 +14,17 @@ * */ -use crate::domain::encryptor::Encryptor; -use crate::domain::encryption_engine::EncryptionEngine; use crate::domain::clusterkey::entity::{ClusterKey, SecClusterKey}; use crate::domain::clusterkey::repository::Repository as ClusterKeyRepository; +use crate::domain::encryption_engine::EncryptionEngine; +use crate::domain::encryptor::Encryptor; use crate::util::error::{Error, Result}; use crate::util::key; use async_trait::async_trait; +use chrono::{Duration, Utc}; use config::Value; use std::collections::HashMap; use std::sync::Arc; -use chrono::{Utc, Duration}; use tokio::sync::RwLock; use crate::domain::kms_provider::KMSProvider; @@ -35,7 +35,7 @@ pub struct EncryptionEngineWithClusterKey where C: ClusterKeyRepository, K: KMSProvider + ?Sized, - E: Encryptor + ?Sized + E: Encryptor + ?Sized, { //cluster key repository cluster_repository: C, @@ -43,7 +43,7 @@ where encryptor: Box, rotate_in_days: i64, latest_cluster_key: Arc>, - cluster_key_container: Arc>> // cluster key id -> cluster key + cluster_key_container: Arc>>, // cluster key id -> cluster key } /// considering we have rotated cluster key for safety concern @@ -58,20 +58,25 @@ impl EncryptionEngineWithClusterKey where C: 
ClusterKeyRepository, K: KMSProvider + ?Sized, - E: Encryptor + ?Sized + E: Encryptor + ?Sized, { pub fn new( cluster_repository: C, encryptor: Box, config: &HashMap, - kms_provider: Box) -> Result { + kms_provider: Box, + ) -> Result { let rotate_in_days = config .get("rotate_in_days") .expect("rotate in days should configured") .to_string() - .parse().unwrap_or(DEFAULT_ROTATE_IN_DAYS); - if rotate_in_days < DEFAULT_ROTATE_IN_DAYS { - return Err(Error::ConfigError(format!("rotate in days should greater than {}", rotate_in_days))); + .parse() + .unwrap_or(DEFAULT_ROTATE_IN_DAYS); + if rotate_in_days < DEFAULT_ROTATE_IN_DAYS { + return Err(Error::ConfigError(format!( + "rotate in days should greater than {}", + rotate_in_days + ))); } info!("cluster key will be rotated in {} days", rotate_in_days); Ok(EncryptionEngineWithClusterKey { @@ -80,7 +85,7 @@ where rotate_in_days, latest_cluster_key: Arc::new(RwLock::new(SecClusterKey::default())), kms_provider, - cluster_key_container: Arc::new(RwLock::new(HashMap::new())) + cluster_key_container: Arc::new(RwLock::new(HashMap::new())), }) } async fn append_cluster_key_hex(&self, data: &mut Vec) -> Vec { @@ -97,19 +102,28 @@ where //convert the cluster back and obtain from database, hard code here. let cluster_id: i32 = (data[0] as i32) * 256 + data[1] as i32; if let Some(cluster_key) = self.cluster_key_container.read().await.get(&cluster_id) { - return Ok((*cluster_key).clone()) + return Ok((*cluster_key).clone()); } - let cluster_key = SecClusterKey::load( self.cluster_repository.get_by_id(cluster_id).await?, &self.kms_provider).await?; - self.cluster_key_container.write().await.insert(cluster_id, cluster_key.clone()); + let cluster_key = SecClusterKey::load( + self.cluster_repository.get_by_id(cluster_id).await?, + &self.kms_provider, + ) + .await?; + self.cluster_key_container + .write() + .await + .insert(cluster_id, cluster_key.clone()); Ok(cluster_key) } async fn generate_new_key(&self) -> Result<()> { //generate new key identified with date time let cluster_key = ClusterKey::new( - self.kms_provider.encode( - key::encode_u8_to_hex_string(&self.encryptor.generate_key())) - .await?.as_bytes().to_vec(), + self.kms_provider + .encode(key::encode_u8_to_hex_string(&self.encryptor.generate_key())) + .await? + .as_bytes() + .to_vec(), self.encryptor.algorithm().to_string(), )?; //insert when no records @@ -124,7 +138,10 @@ where "can't find latest cluster key from database".to_string(), )) } - Some(cluster) => *self.latest_cluster_key.write().await = SecClusterKey::load(cluster, &self.kms_provider).await?, + Some(cluster) => { + *self.latest_cluster_key.write().await = + SecClusterKey::load(cluster, &self.kms_provider).await? + } }; Ok(()) } @@ -135,7 +152,7 @@ impl EncryptionEngine for EncryptionEngineWithClusterKey where C: ClusterKeyRepository, K: KMSProvider + ?Sized, - E: Encryptor + ?Sized + E: Encryptor + ?Sized, { async fn initialize(&mut self) -> Result<()> { //generate new cluster keys only when there is no db record match the date @@ -144,21 +161,32 @@ where .get_latest(&self.encryptor.algorithm().to_string()) .await?; match key { - Some(k) => *self.latest_cluster_key.write().await = SecClusterKey::load(k, &self.kms_provider).await?, + Some(k) => { + *self.latest_cluster_key.write().await = + SecClusterKey::load(k, &self.kms_provider).await? 
+ } None => { self.generate_new_key().await?; } } - info!("cluster key is found or generated : {}", self.latest_cluster_key.read().await); + info!( + "cluster key is found or generated : {}", + self.latest_cluster_key.read().await + ); Ok(()) } async fn rotate_key(&mut self) -> Result { - if Utc::now() < self.latest_cluster_key.read().await.create_at + Duration::days(self.rotate_in_days) { + if Utc::now() + < self.latest_cluster_key.read().await.create_at + Duration::days(self.rotate_in_days) + { return Ok(false); } self.generate_new_key().await?; - info!("cluster key is rotated : {}", self.latest_cluster_key.read().await); + info!( + "cluster key is rotated : {}", + self.latest_cluster_key.read().await + ); Ok(true) } @@ -167,9 +195,15 @@ where return Ok(content); } //always use latest cluster key to encode data - let mut secret = self - .encryptor - .encrypt(self.latest_cluster_key.read().await.data.unsecure().to_owned(), content)?; + let mut secret = self.encryptor.encrypt( + self.latest_cluster_key + .read() + .await + .data + .unsecure() + .to_owned(), + content, + )?; Ok(self.append_cluster_key_hex(&mut secret).await) } diff --git a/src/infra/encryption/mod.rs b/src/infra/encryption/mod.rs index 64efc98..145d21c 100644 --- a/src/infra/encryption/mod.rs +++ b/src/infra/encryption/mod.rs @@ -1,3 +1,3 @@ pub mod algorithm; -pub mod engine; pub mod dummy_engine; +pub mod engine; diff --git a/src/infra/kms/dummy.rs b/src/infra/kms/dummy.rs index 33e225e..991c952 100644 --- a/src/infra/kms/dummy.rs +++ b/src/infra/kms/dummy.rs @@ -15,13 +15,12 @@ */ use crate::domain::kms_provider::KMSProvider; -use crate::util::error::{Result}; +use crate::util::error::Result; +use async_trait::async_trait; use config::Value; use std::collections::HashMap; -use async_trait::async_trait; -pub struct DummyKMS { -} +pub struct DummyKMS {} impl DummyKMS { pub fn new(_config: &HashMap) -> Result { @@ -55,4 +54,4 @@ mod test { let decoded_content = dummy_kms.decode(encoded_content).await.unwrap(); assert_eq!(content, decoded_content); } -} \ No newline at end of file +} diff --git a/src/infra/kms/factory.rs b/src/infra/kms/factory.rs index d40bf3e..0d804b5 100644 --- a/src/infra/kms/factory.rs +++ b/src/infra/kms/factory.rs @@ -14,15 +14,14 @@ * */ -use crate::infra::kms::huaweicloud::HuaweiCloudKMS; -use crate::infra::kms::dummy::DummyKMS; use crate::domain::kms_provider::{KMSProvider, KMSType}; -use crate::util::error::{Result}; +use crate::infra::kms::dummy::DummyKMS; +use crate::infra::kms::huaweicloud::HuaweiCloudKMS; +use crate::util::error::Result; use config::Value; use std::collections::HashMap; use std::str::FromStr; - pub struct KMSProviderFactory {} impl KMSProviderFactory { @@ -52,7 +51,7 @@ mod test { config.insert("type".to_string(), Value::from("not_existed")); assert!(KMSProviderFactory::new_provider(&config).is_err()); config.insert("type".to_string(), Value::from("dummy")); - KMSProviderFactory::new_provider(&config).expect("kms provider from valid string should succeed"); + KMSProviderFactory::new_provider(&config) + .expect("kms provider from valid string should succeed"); } } - diff --git a/src/infra/kms/huaweicloud.rs b/src/infra/kms/huaweicloud.rs index e7d9a6f..e7c0ce9 100644 --- a/src/infra/kms/huaweicloud.rs +++ b/src/infra/kms/huaweicloud.rs @@ -19,11 +19,11 @@ use crate::util::error::{Error, Result}; use async_trait::async_trait; use config::Value; use reqwest::{header::HeaderValue, Client, StatusCode}; +use secstr::*; use serde::{Deserialize, Serialize}; use serde_json::json; use 
std::collections::HashMap; use tokio::sync::Mutex; -use secstr::*; static SIGN_HEADER: &str = "x-auth-token"; static AUTH_HEADER: &str = "x-subject-token"; @@ -65,10 +65,12 @@ impl HuaweiCloudKMS { .get("username") .unwrap_or(&Value::default()) .to_string(), - password: SecUtf8::from(config - .get("password") - .unwrap_or(&Value::default()) - .to_string()), + password: SecUtf8::from( + config + .get("password") + .unwrap_or(&Value::default()) + .to_string(), + ), domain: config .get("domain") .unwrap_or(&Value::default()) @@ -232,7 +234,10 @@ mod test { use super::*; use mockito; - fn get_kms_config(iam_endpoint: Option, kms_endpoint: Option) -> HashMap { + fn get_kms_config( + iam_endpoint: Option, + kms_endpoint: Option, + ) -> HashMap { let mut config: HashMap = HashMap::new(); config.insert("kms_id".to_string(), Value::from("fake_kms_id")); config.insert("username".to_string(), Value::from("fake_username")); @@ -241,14 +246,22 @@ mod test { config.insert("project_name".to_string(), Value::from("fake_project_name")); config.insert("project_id".to_string(), Value::from("fake_project_id")); match iam_endpoint { - None => {config.insert("iam_endpoint".to_string(), Value::from("fake_endpoint"));} - Some(value) => {config.insert("iam_endpoint".to_string(), Value::from(value));} + None => { + config.insert("iam_endpoint".to_string(), Value::from("fake_endpoint")); + } + Some(value) => { + config.insert("iam_endpoint".to_string(), Value::from(value)); + } } match kms_endpoint { - None => {config.insert("kms_endpoint".to_string(), Value::from("fake_endpoint"));} - Some(value) => {config.insert("kms_endpoint".to_string(), Value::from(value));} + None => { + config.insert("kms_endpoint".to_string(), Value::from("fake_endpoint")); + } + Some(value) => { + config.insert("kms_endpoint".to_string(), Value::from(value)); + } } - return config + return config; } #[tokio::test] @@ -261,27 +274,32 @@ mod test { let config = get_kms_config(Some(iam_url.clone()), Some(kms_url)); // Mock auth request - let mock_auth = iam_server.mock("POST", "/v3/auth/tokens") + let mock_auth = iam_server + .mock("POST", "/v3/auth/tokens") .with_status(201) .with_header(AUTH_HEADER, "fake_auth_header") .create(); - let mock_encode = kms_server.mock("POST", "/v1.0/fake_project_id/kms/encrypt-data") + let mock_encode = kms_server + .mock("POST", "/v1.0/fake_project_id/kms/encrypt-data") .with_status(200) .match_header(SIGN_HEADER, "fake_auth_header") .with_body(r#"{"key_id": "123", "cipher_text": "encoded"}"#) .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); - let result = kms_client.encode("raw_content".to_string()).await.expect("request invoke should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let result = kms_client + .encode("raw_content".to_string()) + .await + .expect("request invoke should be successful"); assert_eq!("encoded", result); mock_auth.assert(); mock_encode.assert(); } - #[tokio::test] async fn test_huaweicloud_decode_successful() { // Request a new server from the pool @@ -292,20 +310,26 @@ mod test { let config = get_kms_config(Some(iam_url.clone()), Some(kms_url)); // Mock auth request - let mock_auth = iam_server.mock("POST", "/v3/auth/tokens") + let mock_auth = iam_server + .mock("POST", "/v3/auth/tokens") .with_status(201) .with_header(AUTH_HEADER, "fake_auth_header") .create(); - let mock_decode = kms_server.mock("POST", 
"/v1.0/fake_project_id/kms/decrypt-data") + let mock_decode = kms_server + .mock("POST", "/v1.0/fake_project_id/kms/decrypt-data") .with_status(200) .match_header(SIGN_HEADER, "fake_auth_header") .with_body(r#"{"key_id": "123", "plain_text": "decoded", "plain_text_base64": "123"}"#) .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); - let result = kms_client.decode("raw_content".to_string()).await.expect("request invoke should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let result = kms_client + .decode("raw_content".to_string()) + .await + .expect("request invoke should be successful"); assert_eq!("decoded", result); mock_auth.assert(); @@ -320,7 +344,8 @@ mod test { let config = get_kms_config(Some(url.clone()), None); // Mock auth request - let mock_auth = server.mock("POST", "/v3/auth/tokens") + let mock_auth = server + .mock("POST", "/v3/auth/tokens") .with_status(201) .create(); @@ -328,7 +353,8 @@ mod test { let fake_request = json!({ "fake_attribute": "123", }); - let mock_request = server.mock("POST", "/kms/fake_endpoint") + let mock_request = server + .mock("POST", "/kms/fake_endpoint") .with_status(200) .match_header(SIGN_HEADER, "fake_auth_header") .match_body(mockito::Matcher::Json(fake_request.clone())) @@ -336,11 +362,19 @@ mod test { .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); - kms_client.auth_token_cache.lock().await.push_str("fake_auth_header"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + kms_client + .auth_token_cache + .lock() + .await + .push_str("fake_auth_header"); let request_url = format!("{}/kms/fake_endpoint", url); - let result = kms_client.do_request(&request_url, &fake_request).await.expect("request invoke should be successful"); + let result = kms_client + .do_request(&request_url, &fake_request) + .await + .expect("request invoke should be successful"); let decoded: DecodeData = serde_json::from_value(result).expect("deserialize should ok"); assert_eq!("123", decoded.key_id); @@ -359,7 +393,8 @@ mod test { let config = get_kms_config(Some(url.clone()), None); // Mock auth request - let mock_auth = server.mock("POST", "/v3/auth/tokens") + let mock_auth = server + .mock("POST", "/v3/auth/tokens") .with_status(201) .with_header(AUTH_HEADER, "fake_auth_header") .create(); @@ -368,7 +403,8 @@ mod test { let fake_request = json!({ "fake_attribute": "123", }); - let mock_request = server.mock("POST", "/kms/decrypt-data") + let mock_request = server + .mock("POST", "/kms/decrypt-data") .with_status(200) .match_header(SIGN_HEADER, "fake_auth_header") .match_body(mockito::Matcher::Json(fake_request.clone())) @@ -376,9 +412,13 @@ mod test { .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); let request_url = format!("{}/kms/decrypt-data", url); - let result = kms_client.do_request(&request_url, &fake_request).await.expect("request invoke should be successful"); + let result = kms_client + .do_request(&request_url, &fake_request) + .await + .expect("request invoke should be successful"); let decoded: DecodeData = 
serde_json::from_value(result).expect("deserialize should ok"); assert_eq!("123", decoded.key_id); @@ -397,7 +437,8 @@ mod test { let config = get_kms_config(Some(url.clone()), None); // Mock auth request - let mock_auth = server.mock("POST", "/v3/auth/tokens") + let mock_auth = server + .mock("POST", "/v3/auth/tokens") .with_status(201) .with_header(AUTH_HEADER, "fake_auth_header") .create(); @@ -406,7 +447,8 @@ mod test { let fake_request = json!({ "fake_attribute": "123", }); - let mock_request = server.mock("POST", "/kms/fake_endpoint") + let mock_request = server + .mock("POST", "/kms/fake_endpoint") .with_status(403) .match_header(SIGN_HEADER, "fake_auth_header") .match_body(mockito::Matcher::Json(fake_request.clone())) @@ -414,9 +456,13 @@ mod test { .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); let request_url = format!("{}/kms/fake_endpoint", url); - let _result = kms_client.do_request(&request_url, &fake_request).await.expect_err("always failed to invoke request"); + let _result = kms_client + .do_request(&request_url, &fake_request) + .await + .expect_err("always failed to invoke request"); //auth and request should be invoked twice. mock_auth.expect_at_least(2).assert(); @@ -431,16 +477,21 @@ mod test { let config = get_kms_config(Some(url), None); // Create a mock server - let mock = server.mock("POST", "/v3/auth/tokens") + let mock = server + .mock("POST", "/v3/auth/tokens") .with_status(500) .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); //test auth request assert_eq!(true, kms_client.auth_token_cache.lock().await.is_empty()); - kms_client.auth_request().await.expect_err("auth request failed with 500 status code"); + kms_client + .auth_request() + .await + .expect_err("auth request failed with 500 status code"); mock.assert(); } @@ -476,19 +527,27 @@ mod test { }); // Create a mock server - let mock = server.mock("POST", "/v3/auth/tokens") + let mock = server + .mock("POST", "/v3/auth/tokens") .with_status(201) .with_header(AUTH_HEADER, "fake_auth_header") .match_body(mockito::Matcher::Json(request_body)) .create(); //create kms client - let kms_client = HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); + let kms_client = + HuaweiCloudKMS::new(&config).expect("create huaweicloud client should be successful"); //test auth request assert_eq!(true, kms_client.auth_token_cache.lock().await.is_empty()); - kms_client.auth_request().await.expect("auth request should be successful"); - assert_eq!("fake_auth_header", kms_client.auth_token_cache.lock().await.as_str()); + kms_client + .auth_request() + .await + .expect("auth request should be successful"); + assert_eq!( + "fake_auth_header", + kms_client.auth_token_cache.lock().await.as_str() + ); mock.assert(); } } diff --git a/src/infra/kms/mod.rs b/src/infra/kms/mod.rs index b9ed6fb..f44f91f 100644 --- a/src/infra/kms/mod.rs +++ b/src/infra/kms/mod.rs @@ -1,3 +1,3 @@ +pub mod dummy; pub mod factory; pub mod huaweicloud; -pub mod dummy; diff --git a/src/infra/mod.rs b/src/infra/mod.rs index 80b886d..dd47004 100644 --- a/src/infra/mod.rs +++ b/src/infra/mod.rs @@ -1,5 +1,5 @@ -pub mod encryption; pub mod 
database; +pub mod encryption; pub mod kms; -pub mod sign_plugin; pub mod sign_backend; +pub mod sign_plugin; diff --git a/src/infra/sign_backend/factory.rs b/src/infra/sign_backend/factory.rs index 9804187..29a2e4c 100644 --- a/src/infra/sign_backend/factory.rs +++ b/src/infra/sign_backend/factory.rs @@ -14,24 +14,28 @@ * */ -use std::str::FromStr; use crate::domain::sign_service::{SignBackend, SignBackendType}; -use crate::util::error::{Result}; -use std::sync::{Arc, RwLock}; -use config::{Config}; -use sea_orm::DatabaseConnection; use crate::infra::sign_backend::memory::backend::MemorySignBackend; +use crate::util::error::Result; +use config::Config; +use sea_orm::DatabaseConnection; +use std::str::FromStr; +use std::sync::{Arc, RwLock}; pub struct SignBackendFactory {} impl SignBackendFactory { - pub async fn new_engine(config: Arc>, db_connection: &'static DatabaseConnection) -> Result> { - let engine_type = SignBackendType::from_str( - config.read()?.get_string("sign-backend.type")?.as_str(), - )?; + pub async fn new_engine( + config: Arc>, + db_connection: &'static DatabaseConnection, + ) -> Result> { + let engine_type = + SignBackendType::from_str(config.read()?.get_string("sign-backend.type")?.as_str())?; info!("sign backend configured with plugin {:?}", engine_type); match engine_type { - SignBackendType::Memory => Ok(Box::new(MemorySignBackend::new(config, db_connection).await?)), + SignBackendType::Memory => Ok(Box::new( + MemorySignBackend::new(config, db_connection).await?, + )), } } } diff --git a/src/infra/sign_backend/memory/backend.rs b/src/infra/sign_backend/memory/backend.rs index 176a8c5..8f914b5 100644 --- a/src/infra/sign_backend/memory/backend.rs +++ b/src/infra/sign_backend/memory/backend.rs @@ -14,7 +14,6 @@ * */ - use std::collections::HashMap; use crate::domain::sign_service::SignBackend; @@ -23,19 +22,18 @@ use std::sync::Arc; use config::Config; use std::sync::RwLock; +use crate::domain::datakey::entity::DataKey; +use crate::domain::datakey::entity::{RevokedKey, SecDataKey, INFRA_CONFIG_DOMAIN_NAME}; +use crate::domain::encryption_engine::EncryptionEngine; use crate::infra::database::model::clusterkey::repository; +use crate::infra::encryption::algorithm::factory::AlgorithmFactory; +use crate::infra::encryption::engine::EncryptionEngineWithClusterKey; use crate::infra::kms::factory; -use crate::infra::encryption::engine::{EncryptionEngineWithClusterKey}; -use crate::domain::encryption_engine::EncryptionEngine; -use crate::domain::datakey::entity::{INFRA_CONFIG_DOMAIN_NAME, RevokedKey, SecDataKey}; use crate::infra::sign_plugin::signers::Signers; -use crate::domain::datakey::entity::DataKey; use crate::util::error::{Error, Result}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use sea_orm::DatabaseConnection; -use crate::infra::encryption::algorithm::factory::AlgorithmFactory; - /// Memory Sign Backend will perform all sensitive operations directly in host memory. pub struct MemorySignBackend { @@ -50,14 +48,18 @@ impl MemorySignBackend { /// 2. initialize the cluster repo /// 2. initialize the encryption engine including the cluster key /// 3. initialize the signing plugins - pub async fn new(server_config: Arc>, db_connection: &'static DatabaseConnection) -> Result { + pub async fn new( + server_config: Arc>, + db_connection: &'static DatabaseConnection, + ) -> Result { //initialize the kms backend let kms_provider = factory::KMSProviderFactory::new_provider( - &server_config.read()?.get_table("memory.kms-provider")? 
+ &server_config.read()?.get_table("memory.kms-provider")?, )?; - let repository = - repository::ClusterKeyRepository::new(db_connection); - let engine_config = server_config.read()?.get_table("memory.encryption-engine")?; + let repository = repository::ClusterKeyRepository::new(db_connection); + let engine_config = server_config + .read()? + .get_table("memory.encryption-engine")?; let encryptor = AlgorithmFactory::new_algorithm( &engine_config .get("algorithm") @@ -68,13 +70,16 @@ impl MemorySignBackend { repository, encryptor, &engine_config, - kms_provider + kms_provider, )?; engine.initialize().await?; - let infra_configs = HashMap::from([ - (INFRA_CONFIG_DOMAIN_NAME.to_string(), server_config.read()?.get_string("control-server.domain_name")?), - ]); + let infra_configs = HashMap::from([( + INFRA_CONFIG_DOMAIN_NAME.to_string(), + server_config + .read()? + .get_string("control-server.domain_name")?, + )]); Ok(MemorySignBackend { server_config, @@ -88,7 +93,10 @@ impl MemorySignBackend { impl SignBackend for MemorySignBackend { async fn validate_and_update(&self, data_key: &mut DataKey) -> Result<()> { if let Err(err) = Signers::validate_and_update(data_key) { - return Err(Error::ParameterError(format!("failed to validate imported key content: {}", err))); + return Err(Error::ParameterError(format!( + "failed to validate imported key content: {}", + err + ))); } data_key.private_key = self.engine.encode(data_key.private_key.clone()).await?; data_key.public_key = self.engine.encode(data_key.public_key.clone()).await?; @@ -98,8 +106,8 @@ impl SignBackend for MemorySignBackend { async fn generate_keys(&self, data_key: &mut DataKey) -> Result<()> { let sec_key = SecDataKey::load(data_key, &self.engine).await?; - let content = Signers::load_from_data_key(&data_key.key_type, sec_key)?.generate_keys( - &data_key.key_type, &self.infra_configs)?; + let content = Signers::load_from_data_key(&data_key.key_type, sec_key)? 
+ .generate_keys(&data_key.key_type, &self.infra_configs)?; data_key.private_key = self.engine.encode(content.private_key).await?; data_key.public_key = self.engine.encode(content.public_key).await?; data_key.certificate = self.engine.encode(content.certificate).await?; @@ -112,7 +120,12 @@ impl SignBackend for MemorySignBackend { self.engine.rotate_key().await } - async fn sign(&self, data_key: &DataKey, content: Vec, options: HashMap) -> Result> { + async fn sign( + &self, + data_key: &DataKey, + content: Vec, + options: HashMap, + ) -> Result> { let sec_key = SecDataKey::load(data_key, &self.engine).await?; Signers::load_from_data_key(&data_key.key_type, sec_key)?.sign(content, options) } @@ -123,8 +136,18 @@ impl SignBackend for MemorySignBackend { Ok(()) } - async fn generate_crl_content(&self, data_key: &DataKey, revoked_keys: Vec, last_update: DateTime, next_update: DateTime) -> Result> { + async fn generate_crl_content( + &self, + data_key: &DataKey, + revoked_keys: Vec, + last_update: DateTime, + next_update: DateTime, + ) -> Result> { let sec_key = SecDataKey::load(data_key, &self.engine).await?; - Signers::load_from_data_key(&data_key.key_type, sec_key)?.generate_crl_content(revoked_keys, last_update, next_update) + Signers::load_from_data_key(&data_key.key_type, sec_key)?.generate_crl_content( + revoked_keys, + last_update, + next_update, + ) } -} \ No newline at end of file +} diff --git a/src/infra/sign_backend/memory/mod.rs b/src/infra/sign_backend/memory/mod.rs index d882b4c..fceb141 100644 --- a/src/infra/sign_backend/memory/mod.rs +++ b/src/infra/sign_backend/memory/mod.rs @@ -1 +1 @@ -pub mod backend; \ No newline at end of file +pub mod backend; diff --git a/src/infra/sign_plugin/mod.rs b/src/infra/sign_plugin/mod.rs index b3288e8..4fa0dcc 100644 --- a/src/infra/sign_plugin/mod.rs +++ b/src/infra/sign_plugin/mod.rs @@ -1,4 +1,4 @@ pub mod openpgp; -pub mod x509; pub mod signers; pub mod util; +pub mod x509; diff --git a/src/infra/sign_plugin/openpgp.rs b/src/infra/sign_plugin/openpgp.rs index ca24fad..d779921 100644 --- a/src/infra/sign_plugin/openpgp.rs +++ b/src/infra/sign_plugin/openpgp.rs @@ -19,62 +19,82 @@ use crate::domain::sign_plugin::SignPlugins; use crate::util::error::{Error, Result}; use crate::util::options; use chrono::{DateTime, Utc}; -use pgp::composed::signed_key::{SignedSecretKey, SignedPublicKey}; -use pgp::composed::{key::SecretKeyParamsBuilder}; +use pgp::composed::key::SecretKeyParamsBuilder; +use pgp::composed::signed_key::{SignedPublicKey, SignedSecretKey}; use pgp::crypto::{hash::HashAlgorithm, sym::SymmetricKeyAlgorithm}; use pgp::packet::SignatureConfig; -use pgp::packet::{Subpacket, SubpacketData}; use pgp::packet::*; +use pgp::packet::{Subpacket, SubpacketData}; +use super::util::{attributes_validate, validate_utc_time, validate_utc_time_not_expire}; +use crate::domain::datakey::entity::{ + DataKey, DataKeyContent, KeyType as EntityKeyType, RevokedKey, SecDataKey, +}; +use crate::domain::datakey::plugins::openpgp::{ + OpenPGPDigestAlgorithm, OpenPGPKeyType, PGP_VALID_KEY_SIZE, +}; +use crate::util::key::encode_u8_to_hex_string; +#[allow(unused_imports)] +use enum_iterator::all; +use pgp::composed::StandaloneSignature; use pgp::types::KeyTrait; use pgp::types::{CompressionAlgorithm, SecretKeyTrait}; use pgp::Deserializable; use serde::Deserialize; use smallvec::*; use std::collections::HashMap; -use std::io::{Cursor}; +use std::io::Cursor; use std::str::from_utf8; use std::str::FromStr; use validator::{Validate, ValidationError}; -use 
pgp::composed::StandaloneSignature; -use crate::domain::datakey::entity::{DataKey, DataKeyContent, SecDataKey, KeyType as EntityKeyType, RevokedKey}; -use crate::domain::datakey::plugins::openpgp::{OpenPGPDigestAlgorithm, OpenPGPKeyType, PGP_VALID_KEY_SIZE}; -use crate::util::key::encode_u8_to_hex_string; -use super::util::{validate_utc_time_not_expire, validate_utc_time, attributes_validate}; -#[allow(unused_imports)] -use enum_iterator::{all}; - #[derive(Debug, Validate, Deserialize)] pub struct PgpKeyImportParameter { key_type: OpenPGPKeyType, - #[validate(custom(function = "validate_key_size", message="invalid openpgp attribute 'key_length'"))] + #[validate(custom( + function = "validate_key_size", + message = "invalid openpgp attribute 'key_length'" + ))] key_length: Option, digest_algorithm: OpenPGPDigestAlgorithm, - #[validate(custom(function = "validate_utc_time", message="invalid openpgp attribute 'create_at'"))] + #[validate(custom( + function = "validate_utc_time", + message = "invalid openpgp attribute 'create_at'" + ))] create_at: String, - #[validate(custom(function= "validate_utc_time_not_expire", message="invalid openpgp attribute 'expire_at'"))] + #[validate(custom( + function = "validate_utc_time_not_expire", + message = "invalid openpgp attribute 'expire_at'" + ))] expire_at: String, - passphrase: Option + passphrase: Option, } - #[derive(Debug, Validate, Deserialize)] pub struct PgpKeyGenerationParameter { - #[validate(length(min = 4, max = 100, message="invalid openpgp attribute 'name'"))] + #[validate(length(min = 4, max = 100, message = "invalid openpgp attribute 'name'"))] name: String, // email format validation is disabled due to copr has the case of group prefixed email: `@group#project@copr.com` - #[validate(length(min = 2, max = 100, message="invalid openpgp attribute 'email'"))] + #[validate(length(min = 2, max = 100, message = "invalid openpgp attribute 'email'"))] email: String, key_type: OpenPGPKeyType, - #[validate(custom(function = "validate_key_size", message="invalid openpgp attribute 'key_length'"))] + #[validate(custom( + function = "validate_key_size", + message = "invalid openpgp attribute 'key_length'" + ))] key_length: Option, digest_algorithm: OpenPGPDigestAlgorithm, - #[validate(custom(function = "validate_utc_time", message="invalid openpgp attribute 'create_at'"))] + #[validate(custom( + function = "validate_utc_time", + message = "invalid openpgp attribute 'create_at'" + ))] create_at: String, - #[validate(custom(function= "validate_utc_time_not_expire", message="invalid openpgp attribute 'expire_at'"))] + #[validate(custom( + function = "validate_utc_time_not_expire", + message = "invalid openpgp attribute 'expire_at'" + ))] expire_at: String, - passphrase: Option + passphrase: Option, } impl PgpKeyGenerationParameter { @@ -85,7 +105,9 @@ impl PgpKeyGenerationParameter { fn validate_key_size(key_size: &str) -> std::result::Result<(), ValidationError> { if !PGP_VALID_KEY_SIZE.contains(&key_size) { - return Err(ValidationError::new("invalid key size, possible values are 2048/3072/4096")); + return Err(ValidationError::new( + "invalid key size, possible values are 2048/3072/4096", + )); } Ok(()) } @@ -99,7 +121,9 @@ pub struct OpenPGPPlugin { } impl OpenPGPPlugin { - pub fn attributes_validate(attr: &HashMap) -> Result { + pub fn attributes_validate( + attr: &HashMap, + ) -> Result { let parameter: PgpKeyGenerationParameter = serde_json::from_str(serde_json::to_string(&attr)?.as_str())?; match parameter.validate() { @@ -114,16 +138,18 @@ impl 
SignPlugins for OpenPGPPlugin { let mut secret_key = None; let mut public_key = None; if !db.private_key.unsecure().is_empty() { - let private = from_utf8(db.private_key.unsecure()).map_err(|e| Error::KeyParseError(e.to_string()))?; - let (value, _) = - SignedSecretKey::from_string(private).map_err(|e| Error::KeyParseError(e.to_string()))?; + let private = from_utf8(db.private_key.unsecure()) + .map_err(|e| Error::KeyParseError(e.to_string()))?; + let (value, _) = SignedSecretKey::from_string(private) + .map_err(|e| Error::KeyParseError(e.to_string()))?; secret_key = Some(value); } if !db.public_key.unsecure().is_empty() { - let public = from_utf8(db.public_key.unsecure()).map_err(|e| Error::KeyParseError(e.to_string()))?; - let (value, _) = - SignedPublicKey::from_string(public).map_err(|e| Error::KeyParseError(e.to_string()))?; + let public = from_utf8(db.public_key.unsecure()) + .map_err(|e| Error::KeyParseError(e.to_string()))?; + let (value, _) = SignedPublicKey::from_string(public) + .map_err(|e| Error::KeyParseError(e.to_string()))?; public_key = Some(value); } Ok(Self { @@ -135,23 +161,25 @@ impl SignPlugins for OpenPGPPlugin { }) } - fn validate_and_update(key: &mut DataKey) -> Result<()> where Self: Sized { + fn validate_and_update(key: &mut DataKey) -> Result<()> + where + Self: Sized, + { let _ = attributes_validate::(&key.attributes)?; //validate keys - let private = from_utf8(&key.private_key).map_err(|e| Error::KeyParseError(e.to_string()))?; - let (secret_key, _) = - SignedSecretKey::from_string(private).map_err(|e| Error::KeyParseError(e.to_string()))?; + let private = + from_utf8(&key.private_key).map_err(|e| Error::KeyParseError(e.to_string()))?; + let (secret_key, _) = SignedSecretKey::from_string(private) + .map_err(|e| Error::KeyParseError(e.to_string()))?; let public = from_utf8(&key.public_key).map_err(|e| Error::KeyParseError(e.to_string()))?; - let (public_key, _) = - SignedPublicKey::from_string(public).map_err(|e| Error::KeyParseError(e.to_string()))?; + let (public_key, _) = SignedPublicKey::from_string(public) + .map_err(|e| Error::KeyParseError(e.to_string()))?; //update key attributes key.fingerprint = encode_u8_to_hex_string(&secret_key.fingerprint()); //NOTE: currently we can not get expire at from openpgp key match public_key.expires_at() { None => {} - Some(time) => { - key.expire_at = time - } + Some(time) => key.expire_at = time, } Ok(()) } @@ -164,14 +192,22 @@ impl SignPlugins for OpenPGPPlugin { todo!() } - fn generate_keys(&self, _key_type: &EntityKeyType, _infra_config: &HashMap) -> Result { + fn generate_keys( + &self, + _key_type: &EntityKeyType, + _infra_config: &HashMap, + ) -> Result { let parameter = attributes_validate::(&self.attributes)?; let mut key_params = SecretKeyParamsBuilder::default(); let create_at = parameter.create_at.parse()?; - let expire :DateTime = parameter.expire_at.parse()?; + let expire: DateTime = parameter.expire_at.parse()?; let duration: core::time::Duration = (expire - Utc::now()).to_std()?; key_params - .key_type(parameter.key_type.get_real_key_type(parameter.key_length.clone())) + .key_type( + parameter + .key_type + .get_real_key_type(parameter.key_length.clone()), + ) .can_create_certificates(false) .can_sign(true) .primary_user_id(parameter.get_user_id()) @@ -182,18 +218,14 @@ impl SignPlugins for OpenPGPPlugin { .expiration(Some(duration)); let secret_key_params = key_params.build()?; let secret_key = secret_key_params.generate()?; - let passwd_fn= || match parameter.passphrase { - None => { - String::new() - 
} - Some(password) => { - password - } + let passwd_fn = || match parameter.passphrase { + None => String::new(), + Some(password) => password, }; let signed_secret_key = secret_key.sign(passwd_fn.clone())?; let public_key = signed_secret_key.public_key(); let signed_public_key = public_key.sign(&signed_secret_key, passwd_fn)?; - Ok(DataKeyContent{ + Ok(DataKeyContent { private_key: signed_secret_key.to_armored_bytes(None)?, public_key: signed_public_key.to_armored_bytes(None)?, certificate: vec![], @@ -205,15 +237,13 @@ impl SignPlugins for OpenPGPPlugin { fn sign(&self, content: Vec, options: HashMap) -> Result> { let mut digest = HashAlgorithm::SHA2_256; if let Some(digest_str) = options.get("digest_algorithm") { - digest = OpenPGPDigestAlgorithm::from_str(digest_str)?.get_real_algorithm(); + digest = OpenPGPDigestAlgorithm::from_str(digest_str)?.get_real_algorithm(); } - let passwd_fn = || return match options.get("passphrase") { - None => { - String::new() - } - Some(password) => { - password.to_string() - } + let passwd_fn = || { + return match options.get("passphrase") { + None => String::new(), + Some(password) => password.to_string(), + }; }; let now = Utc::now(); let secret_key_id = self.secret_key.clone().unwrap().key_id(); @@ -235,12 +265,11 @@ impl SignPlugins for OpenPGPPlugin { .sign(&self.secret_key.clone().unwrap(), passwd_fn, read_cursor) .map_err(|e| Error::SignError(self.identity.clone(), e.to_string()))?; - //detached signature if let Some(detached) = options.get(options::DETACHED) { if detached == "true" { let standard_signature = StandaloneSignature::new(signature_packet); - return Ok(standard_signature.to_armored_bytes(None)?) + return Ok(standard_signature.to_armored_bytes(None)?); } } let mut signature_bytes = Vec::with_capacity(1024); @@ -250,7 +279,12 @@ impl SignPlugins for OpenPGPPlugin { Ok(signature_bytes) } - fn generate_crl_content(&self, _revoked_keys: Vec, _last_update: DateTime, _next_update: DateTime) -> Result> { + fn generate_crl_content( + &self, + _revoked_keys: Vec, + _last_update: DateTime, + _next_update: DateTime, + ) -> Result> { todo!() } } @@ -258,14 +292,14 @@ impl SignPlugins for OpenPGPPlugin { #[cfg(test)] mod test { use super::*; - use chrono::{Duration, Utc}; - use rand::Rng; - use secstr::SecVec; + use crate::domain::datakey::entity::KeyType; use crate::domain::datakey::entity::{KeyState, Visibility}; - use crate::domain::datakey::entity::{KeyType}; use crate::domain::encryption_engine::EncryptionEngine; use crate::infra::encryption::dummy_engine::DummyEngine; use crate::util::options::DETACHED; + use chrono::{Duration, Utc}; + use rand::Rng; + use secstr::SecVec; fn get_encryption_engine() -> Box { Box::new(DummyEngine::default()) @@ -275,16 +309,22 @@ mod test { HashMap::from([ ("name".to_string(), "fake_name".to_string()), ("email".to_string(), "fake_email@email.com".to_string()), - ("key_type".to_string() ,"rsa".to_string()), + ("key_type".to_string(), "rsa".to_string()), ("key_length".to_string(), "2048".to_string()), ("digest_algorithm".to_string(), "sha2_256".to_string()), ("create_at".to_string(), Utc::now().to_string()), - ("expire_at".to_string(), (Utc::now() + Duration::days(365)).to_string()), + ( + "expire_at".to_string(), + (Utc::now() + Duration::days(365)).to_string(), + ), ("passphrase".to_string(), "123456".to_string()), ]) } - fn get_default_datakey(name: Option, parameter: Option>) -> DataKey { + fn get_default_datakey( + name: Option, + parameter: Option>, + ) -> DataKey { let now = Utc::now(); let mut datakey = 
DataKey { id: 0, @@ -323,7 +363,8 @@ mod test { parameter.insert("key_type".to_string(), "invalid".to_string()); attributes_validate::(¶meter).expect_err("invalid key type"); parameter.insert("key_type".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty key type"); + attributes_validate::(¶meter) + .expect_err("invalid empty key type"); for key_type in all::().collect::>() { parameter.insert("key_type".to_string(), key_type.to_string()); attributes_validate::(¶meter).expect("valid key type"); @@ -333,10 +374,12 @@ mod test { #[test] fn test_key_size_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("key_length".to_string(), "1024".to_string()); - attributes_validate::(¶meter).expect_err("invalid key length"); + parameter.insert("key_length".to_string(), "1024".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid key length"); parameter.insert("key_length".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty key length"); + attributes_validate::(¶meter) + .expect_err("invalid empty key length"); for key_length in PGP_VALID_KEY_SIZE { parameter.insert("key_length".to_string(), key_length.to_string()); attributes_validate::(¶meter).expect("valid key length"); @@ -347,22 +390,27 @@ mod test { fn test_digest_algorithm_generate_parameter() { let mut parameter = get_default_parameter(); parameter.insert("digest_algorithm".to_string(), "1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid digest algorithm"); - parameter.insert("digest_algorithm".to_string(),"".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty digest algorithm"); + attributes_validate::(¶meter) + .expect_err("invalid digest algorithm"); + parameter.insert("digest_algorithm".to_string(), "".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty digest algorithm"); for key_length in all::().collect::>() { - parameter.insert("digest_algorithm".to_string(),key_length.to_string()); - attributes_validate::(¶meter).expect("valid digest algorithm"); + parameter.insert("digest_algorithm".to_string(), key_length.to_string()); + attributes_validate::(¶meter) + .expect("valid digest algorithm"); } } #[test] fn test_create_at_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("create_at".to_string(),"1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid create at"); - parameter.insert("create_at".to_string(),"".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty create at"); + parameter.insert("create_at".to_string(), "1234".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid create at"); + parameter.insert("create_at".to_string(), "".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty create at"); parameter.insert("create_at".to_string(), Utc::now().to_string()); attributes_validate::(¶meter).expect("valid create at"); } @@ -370,13 +418,22 @@ mod test { #[test] fn test_expire_at_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("expire_at".to_string(),"1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid expire at"); + parameter.insert("expire_at".to_string(), "1234".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid expire at"); parameter.insert("expire_at".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty expire at"); - 
parameter.insert("expire_at".to_string(),(Utc::now() - Duration::days(1)).to_string()); - attributes_validate::(¶meter).expect_err("expire at expired"); - parameter.insert("expire_at".to_string(), (Utc::now() + Duration::minutes(1)).to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty expire at"); + parameter.insert( + "expire_at".to_string(), + (Utc::now() - Duration::days(1)).to_string(), + ); + attributes_validate::(¶meter) + .expect_err("expire at expired"); + parameter.insert( + "expire_at".to_string(), + (Utc::now() + Duration::minutes(1)).to_string(), + ); attributes_validate::(¶meter).expect("valid expire at"); } @@ -384,9 +441,11 @@ mod test { fn test_email_generate_parameter() { let mut parameter = get_default_parameter(); parameter.insert("email".to_string(), "fake".to_string()); - attributes_validate::(¶meter).expect("invalid email should work"); + attributes_validate::(¶meter) + .expect("invalid email should work"); parameter.insert("email".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty email"); + attributes_validate::(¶meter) + .expect_err("invalid empty email"); parameter.insert("email".to_string(), "tommylikehu@gmail.com".to_string()); attributes_validate::(¶meter).expect("valid email"); } @@ -397,16 +456,23 @@ mod test { let dummy_engine = get_encryption_engine(); let algorithms = all::().collect::>(); //choose 3 random digest algorithm - for _ in [1,2,3] { + for _ in [1, 2, 3] { let num = rand::thread_rng().gen_range(0..algorithms.len()); parameter.insert("digest_algorithm".to_string(), algorithms[num].to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); - let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key with digest {} successfully", algorithms[num]).as_str()); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); + let plugin = + OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); + plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect( + format!("generate key with digest {} successfully", algorithms[num]).as_str(), + ); } - } #[tokio::test] @@ -416,10 +482,16 @@ mod test { for key_size in PGP_VALID_KEY_SIZE { parameter.insert("key_size".to_string(), key_size.to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); - let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key with key size {} successfully", key_size).as_str()); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); + let plugin = + OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); + plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect(format!("generate key with key size {} successfully", key_size).as_str()); } } @@ -427,13 +499,19 @@ mod test { async fn test_generate_key_with_possible_key_type() { let mut parameter = get_default_parameter(); let dummy_engine = get_encryption_engine(); - for key_type in all::().collect::>(){ + for key_type in 
all::().collect::>() { parameter.insert("key_type".to_string(), key_type.to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); - let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key with key type {} successfully", key_type).as_str()); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); + let plugin = + OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); + plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect(format!("generate key with key type {} successfully", key_type).as_str()); } } @@ -442,16 +520,26 @@ mod test { let mut parameter = get_default_parameter(); let dummy_engine = get_encryption_engine(); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key with no passphrase successfully").as_str()); + plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect(format!("generate key with no passphrase successfully").as_str()); parameter.insert("passphrase".to_string(), "".to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key with passphrase successfully").as_str()); + plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect(format!("generate key with passphrase successfully").as_str()); } #[test] @@ -527,7 +615,10 @@ j0AZp6WE datakey.public_key = public_key.as_bytes().to_vec(); datakey.private_key = private_key.as_bytes().to_vec(); OpenPGPPlugin::validate_and_update(&mut datakey).expect("validate and update should work"); - assert_eq!("60780E80350801A395B1B08302A5B5FB87CD058E", datakey.fingerprint); + assert_eq!( + "60780E80350801A395B1B08302A5B5FB87CD058E", + datakey.fingerprint + ); } #[tokio::test] @@ -537,10 +628,15 @@ j0AZp6WE parameter.insert(DETACHED.to_string(), "true".to_string()); let dummy_engine = get_encryption_engine(); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone())), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone())), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = OpenPGPPlugin::new(sec_datakey).expect("create openpgp plugin successfully"); - let keys = plugin.generate_keys(&KeyType::OpenPGP, &HashMap::new()).expect(format!("generate key successfully").as_str()); + let keys = plugin + .generate_keys(&KeyType::OpenPGP, &HashMap::new()) + .expect(format!("generate key successfully").as_str()); let sec_keys = SecDataKey { name: 
"".to_string(), private_key: SecVec::new(keys.private_key.clone()), @@ -551,13 +647,22 @@ j0AZp6WE parent: None, }; let instance = OpenPGPPlugin::new(sec_keys).expect("create openpgp instance successfully"); - let signature = instance.sign(content.to_vec(), parameter).expect("sign successfully"); + let signature = instance + .sign(content.to_vec(), parameter) + .expect("sign successfully"); let signature_text = from_utf8(&signature).expect("signature bytes to string should work"); - assert_eq!(true, signature_text.contains("-----BEGIN PGP SIGNATURE-----")); + assert_eq!( + true, + signature_text.contains("-----BEGIN PGP SIGNATURE-----") + ); assert_eq!(true, signature_text.contains("-----END PGP SIGNATURE-----")); - let (standalone, _) = StandaloneSignature::from_string(signature_text).expect("parse signature successfully"); + let (standalone, _) = + StandaloneSignature::from_string(signature_text).expect("parse signature successfully"); let public = from_utf8(&keys.public_key).expect("parse public key should work"); - let (public_key, _) = SignedPublicKey::from_string(public).expect("parse signed public key should work"); - standalone.verify(&public_key, content).expect("signature matches"); + let (public_key, _) = + SignedPublicKey::from_string(public).expect("parse signed public key should work"); + standalone + .verify(&public_key, content) + .expect("signature matches"); } } diff --git a/src/infra/sign_plugin/signers.rs b/src/infra/sign_plugin/signers.rs index 0576be6..fe791f2 100644 --- a/src/infra/sign_plugin/signers.rs +++ b/src/infra/sign_plugin/signers.rs @@ -14,10 +14,10 @@ * */ +use crate::domain::datakey::entity::{DataKey, KeyType}; use crate::domain::sign_plugin::SignPlugins; use crate::infra::sign_plugin::openpgp::OpenPGPPlugin; use crate::infra::sign_plugin::x509::X509Plugin; -use crate::domain::datakey::entity::{DataKey, KeyType}; use crate::util::error::Result; use crate::domain::datakey::entity::SecDataKey; @@ -25,20 +25,25 @@ use crate::domain::datakey::entity::SecDataKey; pub struct Signers {} impl Signers { - //get responding sign plugin for data signing - pub fn load_from_data_key(key_type: &KeyType, data_key: SecDataKey) -> Result> { + pub fn load_from_data_key( + key_type: &KeyType, + data_key: SecDataKey, + ) -> Result> { match key_type { KeyType::OpenPGP => Ok(Box::new(OpenPGPPlugin::new(data_key)?)), - KeyType::X509CA | KeyType::X509ICA | KeyType::X509EE => Ok(Box::new(X509Plugin::new(data_key)?)), + KeyType::X509CA | KeyType::X509ICA | KeyType::X509EE => { + Ok(Box::new(X509Plugin::new(data_key)?)) + } } } - pub fn validate_and_update(datakey: &mut DataKey) -> Result<()> { match datakey.key_type { KeyType::OpenPGP => OpenPGPPlugin::validate_and_update(datakey), - KeyType::X509CA | KeyType::X509ICA | KeyType::X509EE => X509Plugin::validate_and_update(datakey), + KeyType::X509CA | KeyType::X509ICA | KeyType::X509EE => { + X509Plugin::validate_and_update(datakey) + } } } } diff --git a/src/infra/sign_plugin/util.rs b/src/infra/sign_plugin/util.rs index 597fd67..92c9659 100644 --- a/src/infra/sign_plugin/util.rs +++ b/src/infra/sign_plugin/util.rs @@ -11,43 +11,38 @@ * // NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. * // See the Mulan PSL v2 for more details. 
*/ -use validator::{Validate, ValidationError}; +use crate::util::error::{Error, Result as CommonResult}; use chrono::{DateTime, Utc}; use serde::Deserialize; -use crate::util::error::{Error, Result as CommonResult}; use std::collections::HashMap; +use validator::{Validate, ValidationError}; pub fn validate_utc_time_not_expire(expire: &str) -> Result<(), ValidationError> { let now = Utc::now(); match expire.parse::>() { Ok(expire) => { if expire <= now { - return Err(ValidationError::new("expire time less than current time")) + return Err(ValidationError::new("expire time less than current time")); } Ok(()) - }, - Err(_e) => { - Err(ValidationError::new("failed to parse time string to utc")) } + Err(_e) => Err(ValidationError::new("failed to parse time string to utc")), } } pub fn validate_utc_time(expire: &str) -> Result<(), ValidationError> { match expire.parse::>() { - Ok(_) => { - Ok(()) - }, - Err(_) => { - Err(ValidationError::new("failed to parse time string to utc")) - } + Ok(_) => Ok(()), + Err(_) => Err(ValidationError::new("failed to parse time string to utc")), } } -pub fn attributes_validate Deserialize<'a>>(attr: &HashMap) -> CommonResult { - let parameter:T = - serde_json::from_str(serde_json::to_string(attr)?.as_str())?; +pub fn attributes_validate Deserialize<'a>>( + attr: &HashMap, +) -> CommonResult { + let parameter: T = serde_json::from_str(serde_json::to_string(attr)?.as_str())?; match parameter.validate() { Ok(_) => Ok(parameter), Err(e) => Err(Error::ParameterError(format!("{:?}", e))), } -} \ No newline at end of file +} diff --git a/src/infra/sign_plugin/x509.rs b/src/infra/sign_plugin/x509.rs index 2cd290e..2362b7e 100644 --- a/src/infra/sign_plugin/x509.rs +++ b/src/infra/sign_plugin/x509.rs @@ -16,75 +16,110 @@ use std::collections::HashMap; use std::str::FromStr; -use std::time::{SystemTime, Duration}; +use std::time::{Duration, SystemTime}; use chrono::{DateTime, Utc}; +use foreign_types_shared::{ForeignType, ForeignTypeRef}; use openssl::asn1::{Asn1Integer, Asn1Time}; use openssl::bn::{BigNum, MsbOption}; -use openssl::cms::{CmsContentInfo, CMSOptions}; +use openssl::cms::{CMSOptions, CmsContentInfo}; use openssl::hash::MessageDigest; use openssl::nid::Nid; use openssl::pkcs7::{Pkcs7, Pkcs7Flags}; -use openssl::pkey::{PKey}; -use openssl::stack::{Stack}; +use openssl::pkey::PKey; +use openssl::stack::Stack; use openssl::x509; -use openssl::x509::extension::{AuthorityKeyIdentifier, BasicConstraints, KeyUsage, SubjectKeyIdentifier}; +use openssl::x509::extension::{ + AuthorityKeyIdentifier, BasicConstraints, KeyUsage, SubjectKeyIdentifier, +}; use openssl::x509::{X509Crl, X509Extension}; +use openssl_sys::{ + X509_CRL_add0_revoked, X509_CRL_new, X509_CRL_set1_lastUpdate, X509_CRL_set1_nextUpdate, + X509_CRL_set_issuer_name, X509_CRL_sign, X509_REVOKED_new, X509_REVOKED_set_revocationDate, + X509_REVOKED_set_serialNumber, +}; use secstr::SecVec; use serde::Deserialize; -use foreign_types_shared::{ForeignType, ForeignTypeRef}; -use openssl_sys::{X509_CRL_new, X509_CRL_set_issuer_name, X509_CRL_set1_lastUpdate, X509_CRL_add0_revoked, X509_CRL_sign, X509_CRL_set1_nextUpdate, X509_REVOKED_new, X509_REVOKED_set_serialNumber, X509_REVOKED_set_revocationDate}; -use validator::{Validate, ValidationError}; -use crate::util::options; -use crate::util::sign::SignType; -use crate::domain::datakey::entity::{DataKey, DataKeyContent, INFRA_CONFIG_DOMAIN_NAME, KeyType, RevokedKey, SecDataKey, SecParentDateKey}; -use crate::domain::datakey::plugins::x509::{X509KeyType, 
X509_VALID_KEY_SIZE, X509DigestAlgorithm}; -use crate::util::error::{Error, Result}; +use super::util::{attributes_validate, validate_utc_time, validate_utc_time_not_expire}; +use crate::domain::datakey::entity::{ + DataKey, DataKeyContent, KeyType, RevokedKey, SecDataKey, SecParentDateKey, + INFRA_CONFIG_DOMAIN_NAME, +}; +use crate::domain::datakey::plugins::x509::{ + X509DigestAlgorithm, X509KeyType, X509_VALID_KEY_SIZE, +}; use crate::domain::sign_plugin::SignPlugins; +use crate::util::error::{Error, Result}; use crate::util::key::{decode_hex_string_to_u8, encode_u8_to_hex_string}; -use super::util::{validate_utc_time_not_expire, validate_utc_time, attributes_validate}; +use crate::util::options; +use crate::util::sign::SignType; #[allow(unused_imports)] -use enum_iterator::{all}; +use enum_iterator::all; +use validator::{Validate, ValidationError}; #[derive(Debug, Validate, Deserialize)] pub struct X509KeyGenerationParameter { - #[validate(length(min = 1, max = 30, message="invalid x509 subject 'CommonName'"))] + #[validate(length(min = 1, max = 30, message = "invalid x509 subject 'CommonName'"))] common_name: String, - #[validate(length(min = 1, max = 30, message="invalid x509 subject 'OrganizationalUnit'"))] + #[validate(length( + min = 1, + max = 30, + message = "invalid x509 subject 'OrganizationalUnit'" + ))] organizational_unit: String, - #[validate(length(min = 1, max = 30, message="invalid x509 subject 'Organization'"))] + #[validate(length(min = 1, max = 30, message = "invalid x509 subject 'Organization'"))] organization: String, - #[validate(length(min = 1, max = 30, message="invalid x509 subject 'Locality'"))] + #[validate(length(min = 1, max = 30, message = "invalid x509 subject 'Locality'"))] locality: String, - #[validate(length(min = 1, max = 30, message="invalid x509 subject 'StateOrProvinceName'"))] + #[validate(length( + min = 1, + max = 30, + message = "invalid x509 subject 'StateOrProvinceName'" + ))] province_name: String, - #[validate(length(min = 2, max = 2, message="invalid x509 subject 'CountryName'"))] + #[validate(length(min = 2, max = 2, message = "invalid x509 subject 'CountryName'"))] country_name: String, key_type: X509KeyType, - #[validate(custom(function = "validate_x509_key_size", message="invalid x509 attribute 'key_length'"))] + #[validate(custom( + function = "validate_x509_key_size", + message = "invalid x509 attribute 'key_length'" + ))] key_length: String, digest_algorithm: X509DigestAlgorithm, - #[validate(custom(function = "validate_utc_time", message="invalid x509 attribute 'create_at'"))] + #[validate(custom( + function = "validate_utc_time", + message = "invalid x509 attribute 'create_at'" + ))] create_at: String, - #[validate(custom(function= "validate_utc_time_not_expire", message="invalid x509 attribute 'expire_at'"))] + #[validate(custom( + function = "validate_utc_time_not_expire", + message = "invalid x509 attribute 'expire_at'" + ))] expire_at: String, } #[derive(Debug, Validate, Deserialize)] pub struct X509KeyImportParameter { key_type: X509KeyType, - #[validate(custom(function = "validate_x509_key_size", message="invalid x509 attribute 'key_length'"))] + #[validate(custom( + function = "validate_x509_key_size", + message = "invalid x509 attribute 'key_length'" + ))] key_length: String, digest_algorithm: X509DigestAlgorithm, - #[validate(custom(function = "validate_utc_time", message="invalid x509 attribute 'create_at'"))] + #[validate(custom( + function = "validate_utc_time", + message = "invalid x509 attribute 'create_at'" + ))] 
create_at: String, - #[validate(custom(function= "validate_utc_time_not_expire", message="invalid x509 attribute 'expire_at'"))] + #[validate(custom( + function = "validate_utc_time_not_expire", + message = "invalid x509 attribute 'expire_at'" + ))] expire_at: String, } - - impl X509KeyGenerationParameter { pub fn get_subject_name(&self) -> Result { let mut x509_name = x509::X509NameBuilder::new()?; @@ -100,7 +135,9 @@ impl X509KeyGenerationParameter { fn validate_x509_key_size(key_size: &str) -> std::result::Result<(), ValidationError> { if !X509_VALID_KEY_SIZE.contains(&key_size) { - return Err(ValidationError::new("invalid key size, possible values are 2048/3072/4096")); + return Err(ValidationError::new( + "invalid key size, possible values are 2048/3072/4096", + )); } Ok(()) } @@ -118,21 +155,32 @@ pub struct X509Plugin { certificate: SecVec, identity: String, attributes: HashMap, - parent_key: Option + parent_key: Option, } impl X509Plugin { - fn generate_serial_number() -> Result { let mut serial_number = BigNum::new()?; serial_number.rand(128, MsbOption::MAYBE_ZERO, true)?; Ok(serial_number) } - fn generate_crl_endpoint(&self, name: &str, infra_config: &HashMap) -> Result{ let domain_name = infra_config.get(INFRA_CONFIG_DOMAIN_NAME).ok_or( Error::GeneratingKeyError(format!("{} is not configured", INFRA_CONFIG_DOMAIN_NAME)))?; Ok(format!("URI:https://{}/api/v1/keys/{}/crl", domain_name, name)) + fn generate_crl_endpoint( + &self, + name: &str, + infra_config: &HashMap, + ) -> Result { + let domain_name = + infra_config + .get(INFRA_CONFIG_DOMAIN_NAME) + .ok_or(Error::GeneratingKeyError(format!( + "{} is not configured", + INFRA_CONFIG_DOMAIN_NAME + )))?; + Ok(format!( + "URI:https://{}/api/v1/keys/{}/crl", + domain_name, name + )) } //The openssl config for ca would be like: @@ -144,10 +192,15 @@ impl X509Plugin { // nsCertType = objCA // nsComment = "Signatrust Root CA" #[allow(deprecated)] - fn generate_x509ca_keys(&self, _infra_config: &HashMap) -> Result { + fn generate_x509ca_keys( + &self, + _infra_config: &HashMap, + ) -> Result { let parameter = attributes_validate::(&self.attributes)?; //generate self signed certificate - let keys = parameter.key_type.get_real_key_type(parameter.key_length.parse()?)?; + let keys = parameter + .key_type + .get_real_key_type(parameter.key_length.parse()?)?; let mut generator = x509::X509Builder::new()?; let serial_number = X509Plugin::generate_serial_number()?; generator.set_subject_name(parameter.get_subject_name()?.as_ref())?; @@ -155,24 +208,61 @@ impl X509Plugin { generator.set_pubkey(keys.as_ref())?; generator.set_version(2)?; generator.set_serial_number(Asn1Integer::from_bn(serial_number.as_ref())?.as_ref())?; - generator.set_not_before(Asn1Time::days_from_now(days_in_duration(&parameter.create_at)? as u32)?.as_ref())?; - generator.set_not_after(Asn1Time::days_from_now(days_in_duration(&parameter.expire_at)? as u32)?.as_ref())?; + generator.set_not_before( + Asn1Time::days_from_now(days_in_duration(&parameter.create_at)? as u32)?.as_ref(), + )?; + generator.set_not_after( + Asn1Time::days_from_now(days_in_duration(&parameter.expire_at)?
as u32)?.as_ref(), + )?; //ca profile generator.append_extension(BasicConstraints::new().ca().pathlen(1).critical().build()?)?; - generator.append_extension(SubjectKeyIdentifier::new().build(&generator.x509v3_context(None, None))?)?; - generator.append_extension(AuthorityKeyIdentifier::new().keyid(true).issuer(true).build(&generator.x509v3_context(None, None))?)?; - generator.append_extension(KeyUsage::new().crl_sign().digital_signature().key_cert_sign().critical().build()?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_COMMENT, "Signatrust Root CA")?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_CERT_TYPE, "objCA")?)?; + generator.append_extension( + SubjectKeyIdentifier::new().build(&generator.x509v3_context(None, None))?, + )?; + generator.append_extension( + AuthorityKeyIdentifier::new() + .keyid(true) + .issuer(true) + .build(&generator.x509v3_context(None, None))?, + )?; + generator.append_extension( + KeyUsage::new() + .crl_sign() + .digital_signature() + .key_cert_sign() + .critical() + .build()?, + )?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_COMMENT, + "Signatrust Root CA", + )?)?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_CERT_TYPE, + "objCA", + )?)?; - generator.sign(keys.as_ref(), parameter.digest_algorithm.get_real_algorithm())?; + generator.sign( + keys.as_ref(), + parameter.digest_algorithm.get_real_algorithm(), + )?; let cert = generator.build(); - Ok(DataKeyContent{ + Ok(DataKeyContent { private_key: keys.private_key_to_pem_pkcs8()?, public_key: keys.public_key_to_pem()?, certificate: cert.to_pem()?, - fingerprint: encode_u8_to_hex_string(cert.digest( - MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError("unable to generate digester".to_string()))?)?.as_ref()), + fingerprint: encode_u8_to_hex_string( + cert.digest( + MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError( + "unable to generate digester".to_string(), + ))?, + )? 
+ .as_ref(), + ), serial_number: Some(encode_u8_to_hex_string(&serial_number.to_vec())), }) } @@ -187,16 +277,25 @@ impl X509Plugin { // nsCertType = objCA // nsComment = "Signatrust Intermediate CA" #[allow(deprecated)] - fn generate_x509ica_keys(&self, infra_config: &HashMap) -> Result { + fn generate_x509ica_keys( + &self, + infra_config: &HashMap, + ) -> Result { let parameter = attributes_validate::(&self.attributes)?; //load the ca certificate and private key if self.parent_key.is_none() { - return Err(Error::GeneratingKeyError("parent key is not provided".to_string())); + return Err(Error::GeneratingKeyError( + "parent key is not provided".to_string(), + )); } - let ca_key = PKey::private_key_from_pem(self.parent_key.clone().unwrap().private_key.unsecure())?; - let ca_cert = x509::X509::from_pem(self.parent_key.clone().unwrap().certificate.unsecure())?; + let ca_key = + PKey::private_key_from_pem(self.parent_key.clone().unwrap().private_key.unsecure())?; + let ca_cert = + x509::X509::from_pem(self.parent_key.clone().unwrap().certificate.unsecure())?; //generate self signed certificate - let keys = parameter.key_type.get_real_key_type(parameter.key_length.parse()?)?; + let keys = parameter + .key_type + .get_real_key_type(parameter.key_length.parse()?)?; let mut generator = x509::X509Builder::new()?; let serial_number = X509Plugin::generate_serial_number()?; generator.set_subject_name(parameter.get_subject_name()?.as_ref())?; @@ -204,25 +303,68 @@ impl X509Plugin { generator.set_pubkey(keys.as_ref())?; generator.set_version(2)?; generator.set_serial_number(Asn1Integer::from_bn(serial_number.as_ref())?.as_ref())?; - generator.set_not_before(Asn1Time::days_from_now(days_in_duration(¶meter.create_at)? as u32)?.as_ref())?; - generator.set_not_after(Asn1Time::days_from_now(days_in_duration(¶meter.expire_at)? as u32)?.as_ref())?; + generator.set_not_before( + Asn1Time::days_from_now(days_in_duration(¶meter.create_at)? as u32)?.as_ref(), + )?; + generator.set_not_after( + Asn1Time::days_from_now(days_in_duration(¶meter.expire_at)? 
as u32)?.as_ref(), + )?; //ca profile generator.append_extension(BasicConstraints::new().ca().pathlen(0).critical().build()?)?; - generator.append_extension(SubjectKeyIdentifier::new().build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?)?; - generator.append_extension(AuthorityKeyIdentifier::new().keyid(true).issuer(true).build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?)?; - generator.append_extension(KeyUsage::new().crl_sign().digital_signature().key_cert_sign().critical().build()?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::CRL_DISTRIBUTION_POINTS, &self.generate_crl_endpoint(&self.parent_key.clone().unwrap().name, infra_config)?)?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_COMMENT, "Signatrust Intermediate CA")?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_CERT_TYPE, "objCA")?)?; - generator.sign(ca_key.as_ref(), parameter.digest_algorithm.get_real_algorithm())?; + generator.append_extension( + SubjectKeyIdentifier::new() + .build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?, + )?; + generator.append_extension( + AuthorityKeyIdentifier::new() + .keyid(true) + .issuer(true) + .build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?, + )?; + generator.append_extension( + KeyUsage::new() + .crl_sign() + .digital_signature() + .key_cert_sign() + .critical() + .build()?, + )?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::CRL_DISTRIBUTION_POINTS, + &self.generate_crl_endpoint(&self.parent_key.clone().unwrap().name, infra_config)?, + )?)?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_COMMENT, + "Signatrust Intermediate CA", + )?)?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_CERT_TYPE, + "objCA", + )?)?; + generator.sign( + ca_key.as_ref(), + parameter.digest_algorithm.get_real_algorithm(), + )?; let cert = generator.build(); //use parent private key to sign the certificate - Ok(DataKeyContent{ + Ok(DataKeyContent { private_key: keys.private_key_to_pem_pkcs8()?, public_key: keys.public_key_to_pem()?, certificate: cert.to_pem()?, - fingerprint: encode_u8_to_hex_string(cert.digest( - MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError("unable to generate digester".to_string()))?)?.as_ref()), + fingerprint: encode_u8_to_hex_string( + cert.digest( + MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError( + "unable to generate digester".to_string(), + ))?, + )? 
+ .as_ref(), + ), serial_number: Some(encode_u8_to_hex_string(&serial_number.to_vec())), }) } @@ -238,16 +380,25 @@ impl X509Plugin { // nsCertType = objsign // nsComment = "Signatrust Sign Certificate" #[allow(deprecated)] - fn generate_x509ee_keys(&self, infra_config: &HashMap) -> Result { + fn generate_x509ee_keys( + &self, + infra_config: &HashMap, + ) -> Result { let parameter = attributes_validate::(&self.attributes)?; //load the ca certificate and private key if self.parent_key.is_none() { - return Err(Error::GeneratingKeyError("parent key is not provided".to_string())); + return Err(Error::GeneratingKeyError( + "parent key is not provided".to_string(), + )); } - let ca_key = PKey::private_key_from_pem(self.parent_key.clone().unwrap().private_key.unsecure())?; - let ca_cert = x509::X509::from_pem(self.parent_key.clone().unwrap().certificate.unsecure())?; + let ca_key = + PKey::private_key_from_pem(self.parent_key.clone().unwrap().private_key.unsecure())?; + let ca_cert = + x509::X509::from_pem(self.parent_key.clone().unwrap().certificate.unsecure())?; //generate self signed certificate - let keys = parameter.key_type.get_real_key_type(parameter.key_length.parse()?)?; + let keys = parameter + .key_type + .get_real_key_type(parameter.key_length.parse()?)?; let mut generator = x509::X509Builder::new()?; let serial_number = X509Plugin::generate_serial_number()?; generator.set_subject_name(parameter.get_subject_name()?.as_ref())?; @@ -255,25 +406,68 @@ impl X509Plugin { generator.set_pubkey(keys.as_ref())?; generator.set_version(2)?; generator.set_serial_number(Asn1Integer::from_bn(serial_number.as_ref())?.as_ref())?; - generator.set_not_before(Asn1Time::days_from_now(days_in_duration(¶meter.create_at)? as u32)?.as_ref())?; - generator.set_not_after(Asn1Time::days_from_now(days_in_duration(¶meter.expire_at)? as u32)?.as_ref())?; + generator.set_not_before( + Asn1Time::days_from_now(days_in_duration(¶meter.create_at)? as u32)?.as_ref(), + )?; + generator.set_not_after( + Asn1Time::days_from_now(days_in_duration(¶meter.expire_at)? 
as u32)?.as_ref(), + )?; //ca profile generator.append_extension(BasicConstraints::new().critical().build()?)?; - generator.append_extension(SubjectKeyIdentifier::new().build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?)?; - generator.append_extension(AuthorityKeyIdentifier::new().keyid(true).issuer(true).build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?)?; - generator.append_extension(KeyUsage::new().crl_sign().digital_signature().key_cert_sign().critical().build()?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::CRL_DISTRIBUTION_POINTS, &self.generate_crl_endpoint(&self.parent_key.clone().unwrap().name, infra_config)?)?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_COMMENT, "Signatrust Sign Certificate")?)?; - generator.append_extension(X509Extension::new_nid(None, None, Nid::NETSCAPE_CERT_TYPE, "objsign")?)?; - generator.sign(ca_key.as_ref(), parameter.digest_algorithm.get_real_algorithm())?; + generator.append_extension( + SubjectKeyIdentifier::new() + .build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?, + )?; + generator.append_extension( + AuthorityKeyIdentifier::new() + .keyid(true) + .issuer(true) + .build(&generator.x509v3_context(Some(ca_cert.as_ref()), None))?, + )?; + generator.append_extension( + KeyUsage::new() + .crl_sign() + .digital_signature() + .key_cert_sign() + .critical() + .build()?, + )?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::CRL_DISTRIBUTION_POINTS, + &self.generate_crl_endpoint(&self.parent_key.clone().unwrap().name, infra_config)?, + )?)?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_COMMENT, + "Signatrust Sign Certificate", + )?)?; + generator.append_extension(X509Extension::new_nid( + None, + None, + Nid::NETSCAPE_CERT_TYPE, + "objsign", + )?)?; + generator.sign( + ca_key.as_ref(), + parameter.digest_algorithm.get_real_algorithm(), + )?; let cert = generator.build(); //use parent private key to sign the certificate - Ok(DataKeyContent{ + Ok(DataKeyContent { private_key: keys.private_key_to_pem_pkcs8()?, public_key: keys.public_key_to_pem()?, certificate: cert.to_pem()?, - fingerprint: encode_u8_to_hex_string(cert.digest( - MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError("unable to generate digester".to_string()))?)?.as_ref()), + fingerprint: encode_u8_to_hex_string( + cert.digest( + MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError( + "unable to generate digester".to_string(), + ))?, + )? 
+ .as_ref(), + ), serial_number: Some(encode_u8_to_hex_string(&serial_number.to_vec())), }) } @@ -296,7 +490,10 @@ impl SignPlugins for X509Plugin { Ok(plugin) } - fn validate_and_update(key: &mut DataKey) -> Result<()> where Self: Sized { + fn validate_and_update(key: &mut DataKey) -> Result<()> + where + Self: Sized, + { let _ = attributes_validate::(&key.attributes)?; let _private_key = PKey::private_key_from_pem(&key.private_key)?; let certificate = x509::X509::from_pem(&key.certificate)?; @@ -304,10 +501,18 @@ impl SignPlugins for X509Plugin { let _public_key = PKey::public_key_from_pem(&key.public_key)?; } let unix_time = Asn1Time::from_unix(0)?.diff(certificate.not_after())?; - let expire = SystemTime::UNIX_EPOCH + Duration::from_secs(unix_time.days as u64 * 86400 + unix_time.secs as u64); + let expire = SystemTime::UNIX_EPOCH + + Duration::from_secs(unix_time.days as u64 * 86400 + unix_time.secs as u64); key.expire_at = expire.into(); key.fingerprint = encode_u8_to_hex_string( - certificate.digest(MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError("unable to generate digester".to_string()))?)?.as_ref()); + certificate + .digest( + MessageDigest::from_name("sha1").ok_or(Error::GeneratingKeyError( + "unable to generate digester".to_string(), + ))?, + )? + .as_ref(), + ); Ok(()) } @@ -319,12 +524,18 @@ impl SignPlugins for X509Plugin { todo!() } - fn generate_keys(&self, key_type: &KeyType, infra_config: &HashMap) -> Result { + fn generate_keys( + &self, + key_type: &KeyType, + infra_config: &HashMap, + ) -> Result { match key_type { - KeyType::X509CA => { self.generate_x509ca_keys(infra_config) } - KeyType::X509ICA => { self.generate_x509ica_keys(infra_config) } - KeyType::X509EE => { self.generate_x509ee_keys(infra_config) } - _ => { Err(Error::GeneratingKeyError("x509 plugin only support x509ca, x509ica and x509ee key type".to_string())) } + KeyType::X509CA => self.generate_x509ca_keys(infra_config), + KeyType::X509ICA => self.generate_x509ica_keys(infra_config), + KeyType::X509EE => self.generate_x509ee_keys(infra_config), + _ => Err(Error::GeneratingKeyError( + "x509 plugin only support x509ca, x509ica and x509ee key type".to_string(), + )), } } @@ -334,9 +545,15 @@ impl SignPlugins for X509Plugin { let mut cert_stack = Stack::new()?; cert_stack.push(certificate.clone())?; if self.parent_key.is_some() { - cert_stack.push(x509::X509::from_pem(self.parent_key.clone().unwrap().certificate.unsecure())?)?; + cert_stack.push(x509::X509::from_pem( + self.parent_key.clone().unwrap().certificate.unsecure(), + )?)?; } - match SignType::from_str(options.get(options::SIGN_TYPE).unwrap_or(&SignType::Cms.to_string()))? { + match SignType::from_str( + options + .get(options::SIGN_TYPE) + .unwrap_or(&SignType::Cms.to_string()), + )? { SignType::Authenticode => { let p7b = efi_signer::EfiImage::pem_to_p7(self.certificate.unsecure())?; Ok(efi_signer::EfiImage::do_sign_signature( @@ -344,7 +561,9 @@ impl SignPlugins for X509Plugin { p7b, private_key.private_key_to_pem_pkcs8()?, None, - efi_signer::DigestAlgorithm::Sha256)?.encode()?) + efi_signer::DigestAlgorithm::Sha256, + )? + .encode()?) } SignType::PKCS7 => { let pkcs7 = Pkcs7::sign( @@ -355,7 +574,7 @@ impl SignPlugins for X509Plugin { Pkcs7Flags::DETACHED | Pkcs7Flags::NOCERTS | Pkcs7Flags::BINARY - | Pkcs7Flags::NOSMIMECAP + | Pkcs7Flags::NOSMIMECAP, )?; Ok(pkcs7.to_der()?) 
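                // The flag set used above (and mirrored by the CMS branch in the next
                // hunk) produces a detached signature: DETACHED keeps the signed content
                // out of the PKCS#7 blob so data and signature travel separately, NOCERTS
                // omits the signer certificate chain, BINARY signs the bytes as-is without
                // S/MIME text canonicalization, and NOSMIMECAP drops the SMIMECapabilities
                // attribute.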
} @@ -369,56 +588,91 @@ impl SignPlugins for X509Plugin { CMSOptions::DETACHED | CMSOptions::CMS_NOCERTS | CMSOptions::BINARY - | CMSOptions::NOSMIMECAP + | CMSOptions::NOSMIMECAP, )?; Ok(cms_signature.to_der()?) } } } - fn generate_crl_content(&self, revoked_keys: Vec, last_update: DateTime, next_update: DateTime) -> Result> { + fn generate_crl_content( + &self, + revoked_keys: Vec, + last_update: DateTime, + next_update: DateTime, + ) -> Result> { let parameter = attributes_validate::(&self.attributes)?; let private_key = PKey::private_key_from_pem(self.private_key.unsecure())?; let certificate = x509::X509::from_pem(self.certificate.unsecure())?; //prepare raw crl content - let crl = unsafe{ X509_CRL_new() }; - let x509_name= certificate.subject_name().as_ptr(); + let crl = unsafe { X509_CRL_new() }; + let x509_name = certificate.subject_name().as_ptr(); - unsafe { X509_CRL_set_issuer_name(crl, x509_name); }; - unsafe {X509_CRL_set1_lastUpdate(crl, Asn1Time::from_unix(last_update.naive_utc().timestamp())?.as_ptr())}; - unsafe {X509_CRL_set1_nextUpdate(crl, Asn1Time::from_unix(next_update.naive_utc().timestamp())?.as_ptr())}; + unsafe { + X509_CRL_set_issuer_name(crl, x509_name); + }; + unsafe { + X509_CRL_set1_lastUpdate( + crl, + Asn1Time::from_unix(last_update.naive_utc().timestamp())?.as_ptr(), + ) + }; + unsafe { + X509_CRL_set1_nextUpdate( + crl, + Asn1Time::from_unix(next_update.naive_utc().timestamp())?.as_ptr(), + ) + }; for revoked_key in revoked_keys { //TODO: Add revoke reason here. if let Some(serial_number) = revoked_key.serial_number { let cert_serial = BigNum::from_slice(&decode_hex_string_to_u8(&serial_number))?; - let revoked =unsafe{X509_REVOKED_new()}; - unsafe {X509_REVOKED_set_serialNumber(revoked, Asn1Integer::from_bn(&cert_serial)?.as_ptr())}; - unsafe {X509_REVOKED_set_revocationDate(revoked, Asn1Time::from_unix(revoked_key.create_at.naive_utc().timestamp())?.as_ptr())}; - unsafe {X509_CRL_add0_revoked(crl, revoked)}; + let revoked = unsafe { X509_REVOKED_new() }; + unsafe { + X509_REVOKED_set_serialNumber( + revoked, + Asn1Integer::from_bn(&cert_serial)?.as_ptr(), + ) + }; + unsafe { + X509_REVOKED_set_revocationDate( + revoked, + Asn1Time::from_unix(revoked_key.create_at.naive_utc().timestamp())? + .as_ptr(), + ) + }; + unsafe { X509_CRL_add0_revoked(crl, revoked) }; } } - unsafe {X509_CRL_sign(crl, private_key.as_ptr(), parameter.digest_algorithm.get_real_algorithm().as_ptr())}; - let content = unsafe {X509Crl::from_ptr(crl)}; + unsafe { + X509_CRL_sign( + crl, + private_key.as_ptr(), + parameter.digest_algorithm.get_real_algorithm().as_ptr(), + ) + }; + let content = unsafe { X509Crl::from_ptr(crl) }; Ok(content.to_pem()?) 
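        // The unsafe block above builds the CRL through raw FFI, presumably because the
        // safe openssl bindings used here do not expose a CRL builder. Two ownership
        // details worth noting: X509_CRL_add0_revoked follows OpenSSL's "add0" convention
        // and takes ownership of each X509_REVOKED entry, and X509Crl::from_ptr hands the
        // raw CRL pointer to the safe wrapper so it is freed when `content` is dropped.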
} } #[cfg(test)] mod test { - use chrono::{Duration, Utc}; - use std::env; use super::*; - use secstr::SecVec; + use crate::domain::datakey::entity::KeyType; use crate::domain::datakey::entity::{KeyState, ParentKey, Visibility, X509RevokeReason}; - use crate::domain::datakey::entity::{KeyType}; use crate::domain::encryption_engine::EncryptionEngine; use crate::infra::encryption::dummy_engine::DummyEngine; + use chrono::{Duration, Utc}; + use secstr::SecVec; + use std::env; fn get_infra_config() -> HashMap { - HashMap::from([ - (INFRA_CONFIG_DOMAIN_NAME.to_string(), "test.hostname".to_string()), - ]) + HashMap::from([( + INFRA_CONFIG_DOMAIN_NAME.to_string(), + "test.hostname".to_string(), + )]) } fn get_encryption_engine() -> Box { @@ -433,16 +687,23 @@ mod test { ("locality".to_string(), "guangzhou".to_string()), ("province_name".to_string(), "guangzhou".to_string()), ("country_name".to_string(), "cn".to_string()), - ("key_type".to_string() ,"rsa".to_string()), + ("key_type".to_string(), "rsa".to_string()), ("key_length".to_string(), "2048".to_string()), ("digest_algorithm".to_string(), "sha2_256".to_string()), ("create_at".to_string(), Utc::now().to_string()), - ("expire_at".to_string(), (Utc::now() + Duration::days(365)).to_string()), + ( + "expire_at".to_string(), + (Utc::now() + Duration::days(365)).to_string(), + ), ("passphrase".to_string(), "123456".to_string()), ]) } - fn get_default_datakey(name: Option, parameter: Option>, key_type: Option) -> DataKey { + fn get_default_datakey( + name: Option, + parameter: Option>, + key_type: Option, + ) -> DataKey { let now = Utc::now(); let mut datakey = DataKey { id: 0, @@ -485,41 +746,59 @@ mod test { let infra_config = get_infra_config(); // create ca let ca_key = get_default_datakey( - Some("fake ca".to_string()), Some(parameter.clone()), Some(KeyType::X509CA)); - let sec_datakey = SecDataKey::load( - &ca_key, &dummy_engine).await.expect("load sec datakey successfully"); + Some("fake ca".to_string()), + Some(parameter.clone()), + Some(KeyType::X509CA), + ); + let sec_datakey = SecDataKey::load(&ca_key, &dummy_engine) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - let ca_content = plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with no passphrase successfully").as_str()); + let ca_content = plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect(format!("generate ca key with no passphrase successfully").as_str()); // create ica let mut ica_key = get_default_datakey( - Some("fake ica".to_string()), Some(parameter.clone()), Some(KeyType::X509CA)); - ica_key.parent_key = Some(ParentKey{ + Some("fake ica".to_string()), + Some(parameter.clone()), + Some(KeyType::X509CA), + ); + ica_key.parent_key = Some(ParentKey { name: "fake ca".to_string(), private_key: ca_content.private_key, public_key: ca_content.public_key, certificate: ca_content.certificate, attributes: ca_key.attributes.clone(), }); - let sec_datakey = SecDataKey::load( - &ica_key, &dummy_engine).await.expect("load sec datakey successfully"); + let sec_datakey = SecDataKey::load(&ica_key, &dummy_engine) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - let ica_content = plugin.generate_keys(&KeyType::X509ICA, &infra_config).expect(format!("generate ica key with no passphrase successfully").as_str()); + let ica_content = plugin + .generate_keys(&KeyType::X509ICA, 
&infra_config) + .expect(format!("generate ica key with no passphrase successfully").as_str()); //create ee let mut ee_key = get_default_datakey( - Some("fake ee".to_string()), Some(parameter.clone()), Some(KeyType::X509CA)); - ee_key.parent_key = Some(ParentKey{ + Some("fake ee".to_string()), + Some(parameter.clone()), + Some(KeyType::X509CA), + ); + ee_key.parent_key = Some(ParentKey { name: "fake ca".to_string(), private_key: ica_content.private_key, public_key: ica_content.public_key, certificate: ica_content.certificate, attributes: ica_key.attributes.clone(), }); - let sec_datakey = SecDataKey::load( - &ica_key, &dummy_engine).await.expect("load sec datakey successfully"); + let sec_datakey = SecDataKey::load(&ica_key, &dummy_engine) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - let ee_content = plugin.generate_keys(&KeyType::X509EE, &infra_config).expect(format!("generate ee key with no passphrase successfully").as_str()); + let ee_content = plugin + .generate_keys(&KeyType::X509EE, &infra_config) + .expect(format!("generate ee key with no passphrase successfully").as_str()); let sec_keys = SecDataKey { name: "".to_string(), @@ -537,9 +816,11 @@ mod test { fn test_key_type_generate_parameter() { let mut parameter = get_default_parameter(); parameter.insert("key_type".to_string(), "invalid".to_string()); - attributes_validate::(¶meter).expect_err("invalid key type"); + attributes_validate::(¶meter) + .expect_err("invalid key type"); parameter.insert("key_type".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty key type"); + attributes_validate::(¶meter) + .expect_err("invalid empty key type"); for key_type in all::().collect::>() { parameter.insert("key_type".to_string(), key_type.to_string()); attributes_validate::(¶meter).expect("valid key type"); @@ -549,13 +830,16 @@ mod test { #[test] fn test_key_size_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("key_length".to_string(), "1024".to_string()); - attributes_validate::(¶meter).expect_err("invalid key length"); + parameter.insert("key_length".to_string(), "1024".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid key length"); parameter.insert("key_length".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty key length"); + attributes_validate::(¶meter) + .expect_err("invalid empty key length"); for key_length in X509_VALID_KEY_SIZE { parameter.insert("key_length".to_string(), key_length.to_string()); - attributes_validate::(¶meter).expect("valid key length"); + attributes_validate::(¶meter) + .expect("valid key length"); } } @@ -563,22 +847,27 @@ mod test { fn test_digest_algorithm_generate_parameter() { let mut parameter = get_default_parameter(); parameter.insert("digest_algorithm".to_string(), "1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid digest algorithm"); - parameter.insert("digest_algorithm".to_string(),"".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty digest algorithm"); + attributes_validate::(¶meter) + .expect_err("invalid digest algorithm"); + parameter.insert("digest_algorithm".to_string(), "".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty digest algorithm"); for key_length in all::().collect::>() { - parameter.insert("digest_algorithm".to_string(),key_length.to_string()); - attributes_validate::(¶meter).expect("valid digest 
algorithm"); + parameter.insert("digest_algorithm".to_string(), key_length.to_string()); + attributes_validate::(¶meter) + .expect("valid digest algorithm"); } } #[test] fn test_create_at_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("create_at".to_string(),"1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid create at"); - parameter.insert("create_at".to_string(),"".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty create at"); + parameter.insert("create_at".to_string(), "1234".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid create at"); + parameter.insert("create_at".to_string(), "".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty create at"); parameter.insert("create_at".to_string(), Utc::now().to_string()); attributes_validate::(¶meter).expect("valid create at"); } @@ -586,13 +875,22 @@ mod test { #[test] fn test_expire_at_generate_parameter() { let mut parameter = get_default_parameter(); - parameter.insert("expire_at".to_string(),"1234".to_string()); - attributes_validate::(¶meter).expect_err("invalid expire at"); + parameter.insert("expire_at".to_string(), "1234".to_string()); + attributes_validate::(¶meter) + .expect_err("invalid expire at"); parameter.insert("expire_at".to_string(), "".to_string()); - attributes_validate::(¶meter).expect_err("invalid empty expire at"); - parameter.insert("expire_at".to_string(),(Utc::now() - Duration::days(1)).to_string()); - attributes_validate::(¶meter).expect_err("expire at expired"); - parameter.insert("expire_at".to_string(), (Utc::now() + Duration::minutes(1)).to_string()); + attributes_validate::(¶meter) + .expect_err("invalid empty expire at"); + parameter.insert( + "expire_at".to_string(), + (Utc::now() - Duration::days(1)).to_string(), + ); + attributes_validate::(¶meter) + .expect_err("expire at expired"); + parameter.insert( + "expire_at".to_string(), + (Utc::now() + Duration::minutes(1)).to_string(), + ); attributes_validate::(¶meter).expect("valid expire at"); } @@ -605,10 +903,15 @@ mod test { for hash in all::().collect::>() { parameter.insert("digest_algorithm".to_string(), hash.to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone()), Some(KeyType::X509CA)), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone()), Some(KeyType::X509CA)), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with digest {} successfully", hash).as_str()); + plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect(format!("generate ca key with digest {} successfully", hash).as_str()); } } @@ -620,10 +923,17 @@ mod test { for key_size in X509_VALID_KEY_SIZE { parameter.insert("key_size".to_string(), key_size.to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone()), Some(KeyType::X509CA)), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone()), Some(KeyType::X509CA)), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with key 
size {} successfully", key_size).as_str()); + plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect( + format!("generate ca key with key size {} successfully", key_size).as_str(), + ); } } @@ -635,10 +945,17 @@ mod test { for key_type in all::().collect::>() { parameter.insert("key_type".to_string(), key_type.to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone()), Some(KeyType::X509CA)), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone()), Some(KeyType::X509CA)), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with key type {} successfully", key_type).as_str()); + plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect( + format!("generate ca key with key type {} successfully", key_type).as_str(), + ); } } @@ -649,17 +966,27 @@ mod test { let infra_config = get_infra_config(); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone()), Some(KeyType::X509CA)), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone()), Some(KeyType::X509CA)), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with no passphrase successfully").as_str()); + plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect(format!("generate ca key with no passphrase successfully").as_str()); parameter.insert("passphrase".to_string(), "".to_string()); let sec_datakey = SecDataKey::load( - &get_default_datakey( - None, Some(parameter.clone()), Some(KeyType::X509CA)), &dummy_engine).await.expect("load sec datakey successfully"); + &get_default_datakey(None, Some(parameter.clone()), Some(KeyType::X509CA)), + &dummy_engine, + ) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with passphrase successfully").as_str()); + plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect(format!("generate ca key with passphrase successfully").as_str()); } #[test] @@ -726,7 +1053,10 @@ X5BboR/QJakEK+H+EUQAiDs= datakey.private_key = private_key.as_bytes().to_vec(); X509Plugin::validate_and_update(&mut datakey).expect("validate and update should work"); assert_eq!("2123-04-29 09:48:00 UTC", datakey.expire_at.to_string()); - assert_eq!("C9345187DFA0BFB6DCBCC4827BBEA7312E43754B", datakey.fingerprint); + assert_eq!( + "C9345187DFA0BFB6DCBCC4827BBEA7312E43754B", + datakey.fingerprint + ); } #[tokio::test] @@ -734,7 +1064,9 @@ X5BboR/QJakEK+H+EUQAiDs= let parameter = get_default_parameter(); let content = "hello world".as_bytes(); let instance = get_default_plugin().await; - let _signature = instance.sign(content.to_vec(), parameter).expect("sign successfully"); + let _signature = instance + .sign(content.to_vec(), parameter) + .expect("sign successfully"); } #[tokio::test] @@ -744,41 +1076,77 @@ X5BboR/QJakEK+H+EUQAiDs= let infra_config = get_infra_config(); // create ca let mut ca_key = get_default_datakey( - Some("fake ca".to_string()), 
Some(parameter.clone()), Some(KeyType::X509CA)); - let sec_datakey = SecDataKey::load( - &ca_key, &dummy_engine).await.expect("load sec datakey successfully"); + Some("fake ca".to_string()), + Some(parameter.clone()), + Some(KeyType::X509CA), + ); + let sec_datakey = SecDataKey::load(&ca_key, &dummy_engine) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(sec_datakey).expect("create plugin successfully"); - let ca_content = plugin.generate_keys(&KeyType::X509CA, &infra_config).expect(format!("generate ca key with no passphrase successfully").as_str()); + let ca_content = plugin + .generate_keys(&KeyType::X509CA, &infra_config) + .expect(format!("generate ca key with no passphrase successfully").as_str()); ca_key.private_key = ca_content.private_key; ca_key.public_key = ca_content.public_key; ca_key.certificate = ca_content.certificate; ca_key.serial_number = ca_content.serial_number; ca_key.fingerprint = ca_content.fingerprint; - let crl_sec_datakey = SecDataKey::load( - &ca_key, &dummy_engine).await.expect("load sec datakey successfully"); + let crl_sec_datakey = SecDataKey::load(&ca_key, &dummy_engine) + .await + .expect("load sec datakey successfully"); let plugin = X509Plugin::new(crl_sec_datakey).expect("create plugin successfully"); let revoke_time = Utc::now(); let last_update = Utc::now() + Duration::days(1); let next_update = Utc::now() + Duration::days(2); - let serial_number = X509Plugin::generate_serial_number().expect("generate serial number successfully"); - let revoked_keys = RevokedKey{ + let serial_number = + X509Plugin::generate_serial_number().expect("generate serial number successfully"); + let revoked_keys = RevokedKey { id: 0, key_id: 0, ca_id: 0, reason: X509RevokeReason::Unspecified, create_at: revoke_time.clone(), - serial_number: Some(encode_u8_to_hex_string(&serial_number.to_vec())) + serial_number: Some(encode_u8_to_hex_string(&serial_number.to_vec())), }; //generate crl - let content = plugin.generate_crl_content(vec![revoked_keys], last_update.clone(), next_update.clone()).expect("generate crl successfully"); + let content = plugin + .generate_crl_content(vec![revoked_keys], last_update.clone(), next_update.clone()) + .expect("generate crl successfully"); let crl = X509Crl::from_pem(&content).expect("load generated crl successfully"); - assert_eq!(crl.last_update()==Asn1Time::from_unix(last_update.naive_utc().timestamp()).expect("convert to asn1 time successfully"), true); - assert_eq!(crl.next_update().expect("next update is set")==Asn1Time::from_unix(next_update.naive_utc().timestamp()).expect("convert to asn1 time successfully"), true); + assert_eq!( + crl.last_update() + == Asn1Time::from_unix(last_update.naive_utc().timestamp()) + .expect("convert to asn1 time successfully"), + true + ); + assert_eq!( + crl.next_update().expect("next update is set") + == Asn1Time::from_unix(next_update.naive_utc().timestamp()) + .expect("convert to asn1 time successfully"), + true + ); assert_eq!(crl.get_revoked().is_some(), true); - let revoked = crl.get_revoked().expect("revoke stack is not empty").get(0).expect("first revoke is not empty"); - assert_eq!(revoked.serial_number().to_owned().expect("convert to asn1 number work") == Asn1Integer::from_bn(&serial_number).expect("convert from bn number should work"), true); - assert_eq!(revoked.revocation_date().to_owned()==Asn1Time::from_unix(revoke_time.naive_utc().timestamp()).expect("convert to asn1 time successfully"), true); - + let revoked = crl + .get_revoked() + .expect("revoke stack 
is not empty") + .get(0) + .expect("first revoke is not empty"); + assert_eq!( + revoked + .serial_number() + .to_owned() + .expect("convert to asn1 number work") + == Asn1Integer::from_bn(&serial_number) + .expect("convert from bn number should work"), + true + ); + assert_eq!( + revoked.revocation_date().to_owned() + == Asn1Time::from_unix(revoke_time.naive_utc().timestamp()) + .expect("convert to asn1 time successfully"), + true + ); } #[tokio::test] diff --git a/src/presentation/handler/control/datakey_handler.rs b/src/presentation/handler/control/datakey_handler.rs index 6c1d56e..f9fd467 100644 --- a/src/presentation/handler/control/datakey_handler.rs +++ b/src/presentation/handler/control/datakey_handler.rs @@ -14,19 +14,21 @@ * */ +use actix_web::{web, HttpResponse, Responder, Result, Scope}; use std::str::FromStr; -use actix_web::{ - HttpResponse, Responder, Result, web, Scope -}; - -use crate::presentation::handler::control::model::datakey::dto::{CertificateContent, CreateDataKeyDTO, CRLContent, DataKeyDTO, ImportDataKeyDTO, ListKeyQuery, NameIdenticalQuery, PagedDatakeyDTO, PublicKeyContent, RevokeCertificateDTO}; -use crate::util::error::Error; -use validator::Validate; +use super::model::user::dto::UserIdentity; use crate::application::datakey::KeyService; -use crate::domain::datakey::entity::{DataKey, DatakeyPaginationQuery, KeyType, Visibility, X509RevokeReason}; +use crate::domain::datakey::entity::{ + DataKey, DatakeyPaginationQuery, KeyType, Visibility, X509RevokeReason, +}; +use crate::presentation::handler::control::model::datakey::dto::{ + CRLContent, CertificateContent, CreateDataKeyDTO, DataKeyDTO, ImportDataKeyDTO, ListKeyQuery, + NameIdenticalQuery, PagedDatakeyDTO, PublicKeyContent, RevokeCertificateDTO, +}; +use crate::util::error::Error; use crate::util::key::get_datakey_full_name; -use super::model::user::dto::UserIdentity; +use validator::Validate; /// Create new key /// @@ -116,10 +118,16 @@ use super::model::user::dto::UserIdentity; (status = 500, description = "Server internal error", body = ErrorMessage) ) )] -async fn create_data_key(user: UserIdentity, key_service: web::Data, datakey: web::Json,) -> Result { +async fn create_data_key( + user: UserIdentity, + key_service: web::Data, + datakey: web::Json, +) -> Result { datakey.validate()?; let mut key = DataKey::create_from(datakey.0, user.clone())?; - Ok(HttpResponse::Created().json(DataKeyDTO::try_from(key_service.into_inner().create(user, &mut key).await?)?)) + Ok(HttpResponse::Created().json(DataKeyDTO::try_from( + key_service.into_inner().create(user, &mut key).await?, + )?)) } /// Get all available keys from database. @@ -144,11 +152,18 @@ async fn create_data_key(user: UserIdentity, key_service: web::Data, key: web::Query) -> Result { +async fn list_data_key( + user: UserIdentity, + key_service: web::Data, + key: web::Query, +) -> Result { key.validate()?; //test visibility matched. 
Visibility::from_parameter(key.visibility.clone())?; - let keys = key_service.into_inner().get_all(user.id, DatakeyPaginationQuery::from(key.into_inner())).await?; + let keys = key_service + .into_inner() + .get_all(user.id, DatakeyPaginationQuery::from(key.into_inner())) + .await?; Ok(HttpResponse::Ok().json(PagedDatakeyDTO::try_from(keys)?)) } @@ -177,8 +192,15 @@ async fn list_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path) -> Result { - let key = key_service.into_inner().get_one(Some(user), id_or_name.into_inner()).await?; +async fn show_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + let key = key_service + .into_inner() + .get_one(Some(user), id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok().json(DataKeyDTO::try_from(key)?)) } @@ -208,8 +230,15 @@ async fn show_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path) -> Result { - key_service.into_inner().request_delete(user, id_or_name.into_inner()).await?; +async fn delete_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + key_service + .into_inner() + .request_delete(user, id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok()) } @@ -239,8 +268,15 @@ async fn delete_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path) -> Result { - key_service.into_inner().cancel_delete(user, id_or_name.into_inner()).await?; +async fn cancel_delete_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + key_service + .into_inner() + .cancel_delete(user, id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok()) } @@ -271,8 +307,20 @@ async fn cancel_delete_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path, reason: web::Json) -> Result { - key_service.into_inner().request_revoke(user, id_or_name.into_inner(), X509RevokeReason::from_str(&reason.reason)?).await?; +async fn revoke_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, + reason: web::Json, +) -> Result { + key_service + .into_inner() + .request_revoke( + user, + id_or_name.into_inner(), + X509RevokeReason::from_str(&reason.reason)?, + ) + .await?; Ok(HttpResponse::Ok()) } @@ -302,8 +350,15 @@ async fn revoke_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path) -> Result { - key_service.into_inner().cancel_revoke(user, id_or_name.into_inner()).await?; +async fn cancel_revoke_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + key_service + .into_inner() + .cancel_revoke(user, id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok()) } @@ -330,12 +385,20 @@ async fn cancel_revoke_data_key(user: UserIdentity, key_service: web::Data, key_service: web::Data, id_or_name: web::Path) -> Result { - let data_key = key_service.export_one(user, id_or_name.into_inner()).await?; +async fn export_public_key( + user: Option, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + let data_key = key_service + .export_one(user, id_or_name.into_inner()) + .await?; if data_key.key_type != KeyType::OpenPGP { - return Ok(HttpResponse::Forbidden().finish()) + return Ok(HttpResponse::Forbidden().finish()); } - Ok(HttpResponse::Ok().content_type("text/plain").body(PublicKeyContent::try_from(data_key)?.content)) + Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(PublicKeyContent::try_from(data_key)?.content)) } /// Get certificate content of 
specific key by id or name from database @@ -361,12 +424,20 @@ async fn export_public_key(user: Option, key_service: web::Data, key_service: web::Data, id_or_name: web::Path) -> Result { - let data_key = key_service.export_one(user, id_or_name.into_inner()).await?; +async fn export_certificate( + user: Option, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + let data_key = key_service + .export_one(user, id_or_name.into_inner()) + .await?; if data_key.key_type == KeyType::OpenPGP { - return Ok(HttpResponse::Forbidden().finish()) + return Ok(HttpResponse::Forbidden().finish()); } - Ok(HttpResponse::Ok().content_type("text/plain").body(CertificateContent::try_from(data_key)?.content)) + Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(CertificateContent::try_from(data_key)?.content)) } /// Get Client Revoke List content of specific key(cert) by id or name from database @@ -392,13 +463,20 @@ async fn export_certificate(user: Option, key_service: web::Data, key_service: web::Data, id_or_name: web::Path) -> Result { +async fn export_crl( + user: Option, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { //note: we could not get any crl content by a openpgp id. - let crl_content = key_service.export_cert_crl(user, id_or_name.into_inner()).await?; - Ok(HttpResponse::Ok().content_type("text/plain").body(CRLContent::try_from(crl_content)?.content)) + let crl_content = key_service + .export_cert_crl(user, id_or_name.into_inner()) + .await?; + Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(CRLContent::try_from(crl_content)?.content)) } - /// Enable specific key by id or name from database /// /// ## Example @@ -424,8 +502,14 @@ async fn export_crl(user: Option, key_service: web::Data, id_or_name: web::Path) -> Result { - key_service.enable(Some(user), id_or_name.into_inner()).await?; +async fn enable_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + key_service + .enable(Some(user), id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok()) } @@ -454,8 +538,14 @@ async fn enable_data_key(user: UserIdentity, key_service: web::Data, id_or_name: web::Path) -> Result { - key_service.disable(Some(user), id_or_name.into_inner()).await?; +async fn disable_data_key( + user: UserIdentity, + key_service: web::Data, + id_or_name: web::Path, +) -> Result { + key_service + .disable(Some(user), id_or_name.into_inner()) + .await?; Ok(HttpResponse::Ok()) } @@ -483,11 +573,19 @@ async fn disable_data_key(user: UserIdentity, key_service: web::Data, name_exist: web::Query,) -> Result { +async fn key_name_identical( + user: UserIdentity, + key_service: web::Data, + name_exist: web::Query, +) -> Result { name_exist.validate()?; let visibility = Visibility::from_parameter(name_exist.visibility.clone())?; let key_name = get_datakey_full_name(&name_exist.name, &user.email, &visibility)?; - match key_service.into_inner().get_raw_key_by_name(&key_name).await { + match key_service + .into_inner() + .get_raw_key_by_name(&key_name) + .await + { Ok(_) => Ok(HttpResponse::Conflict()), Err(_) => Ok(HttpResponse::Ok()), } @@ -558,29 +656,53 @@ async fn key_name_identical(user: UserIdentity, key_service: web::Data, datakey: web::Json) -> Result { +async fn import_data_key( + user: UserIdentity, + key_service: web::Data, + datakey: web::Json, +) -> Result { datakey.validate()?; let mut key = DataKey::import_from(datakey.0, user)?; - Ok(HttpResponse::Created().json(DataKeyDTO::try_from(key_service.into_inner().import(&mut 
key).await?)?)) + Ok(HttpResponse::Created().json(DataKeyDTO::try_from( + key_service.into_inner().import(&mut key).await?, + )?)) } - pub fn get_scope() -> Scope { web::scope("/keys") .service( web::resource("/") .route(web::get().to(list_data_key)) - .route(web::post().to(create_data_key))) - .service( web::resource("/import").route(web::post().to(import_data_key))) - .service( web::resource("/name_identical").route(web::head().to(key_name_identical))) - .service( web::resource("/{id_or_name}").route(web::get().to(show_data_key))) - .service( web::resource("/{id_or_name}/public_key").route(web::get().to(export_public_key))) - .service( web::resource("/{id_or_name}/certificate").route(web::get().to(export_certificate))) - .service( web::resource("/{id_or_name}/crl").route(web::get().to(export_crl))) - .service( web::resource("/{id_or_name}/actions/enable").route(web::post().to(enable_data_key))) - .service( web::resource("/{id_or_name}/actions/disable").route(web::post().to(disable_data_key))) - .service( web::resource("/{id_or_name}/actions/request_delete").route(web::post().to(delete_data_key))) - .service( web::resource("/{id_or_name}/actions/cancel_delete").route(web::post().to(cancel_delete_data_key))) - .service( web::resource("/{id_or_name}/actions/request_revoke").route(web::post().to(revoke_data_key))) - .service( web::resource("/{id_or_name}/actions/cancel_revoke").route(web::post().to(cancel_revoke_data_key))) + .route(web::post().to(create_data_key)), + ) + .service(web::resource("/import").route(web::post().to(import_data_key))) + .service(web::resource("/name_identical").route(web::head().to(key_name_identical))) + .service(web::resource("/{id_or_name}").route(web::get().to(show_data_key))) + .service(web::resource("/{id_or_name}/public_key").route(web::get().to(export_public_key))) + .service( + web::resource("/{id_or_name}/certificate").route(web::get().to(export_certificate)), + ) + .service(web::resource("/{id_or_name}/crl").route(web::get().to(export_crl))) + .service( + web::resource("/{id_or_name}/actions/enable").route(web::post().to(enable_data_key)), + ) + .service( + web::resource("/{id_or_name}/actions/disable").route(web::post().to(disable_data_key)), + ) + .service( + web::resource("/{id_or_name}/actions/request_delete") + .route(web::post().to(delete_data_key)), + ) + .service( + web::resource("/{id_or_name}/actions/cancel_delete") + .route(web::post().to(cancel_delete_data_key)), + ) + .service( + web::resource("/{id_or_name}/actions/request_revoke") + .route(web::post().to(revoke_data_key)), + ) + .service( + web::resource("/{id_or_name}/actions/cancel_revoke") + .route(web::post().to(cancel_revoke_data_key)), + ) } diff --git a/src/presentation/handler/control/health_handler.rs b/src/presentation/handler/control/health_handler.rs index de311e2..2a43c00 100644 --- a/src/presentation/handler/control/health_handler.rs +++ b/src/presentation/handler/control/health_handler.rs @@ -14,8 +14,8 @@ * */ -use actix_web::{HttpResponse, Responder, Result, web, Scope}; use crate::util::error::Error; +use actix_web::{web, HttpResponse, Responder, Result, Scope}; use crate::application::user::UserService; @@ -39,6 +39,5 @@ async fn health(_user_service: web::Data) -> Result Scope { - web::scope("/health") - .service(web::resource("/").route(web::get().to(health))) -} \ No newline at end of file + web::scope("/health").service(web::resource("/").route(web::get().to(health))) +} diff --git a/src/presentation/handler/control/mod.rs b/src/presentation/handler/control/mod.rs 
index ee5b877..f6eae1b 100644 --- a/src/presentation/handler/control/mod.rs +++ b/src/presentation/handler/control/mod.rs @@ -1,4 +1,4 @@ pub mod datakey_handler; -pub mod user_handler; pub mod health_handler; -pub mod model; \ No newline at end of file +pub mod model; +pub mod user_handler; diff --git a/src/presentation/handler/control/model/datakey/dto.rs b/src/presentation/handler/control/model/datakey/dto.rs index ba32aab..42131a0 100644 --- a/src/presentation/handler/control/model/datakey/dto.rs +++ b/src/presentation/handler/control/model/datakey/dto.rs @@ -14,30 +14,32 @@ * */ -use crate::domain::datakey::entity::{DataKey, DatakeyPaginationQuery, KeyState, PagedDatakey, Visibility, X509CRL}; use crate::domain::datakey::entity::KeyType; +use crate::domain::datakey::entity::{ + DataKey, DatakeyPaginationQuery, KeyState, PagedDatakey, Visibility, X509CRL, +}; use crate::util::error::Result; +use crate::util::key::{get_datakey_full_name, sorted_map}; use chrono::{DateTime, Utc}; use std::str::FromStr; -use crate::util::key::{get_datakey_full_name, sorted_map}; -use validator::{Validate, ValidationError}; -use std::collections::HashMap; +use crate::presentation::handler::control::model::user::dto::UserIdentity; use crate::util::error::Error; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use utoipa::{IntoParams, ToSchema}; -use crate::presentation::handler::control::model::user::dto::UserIdentity; +use validator::{Validate, ValidationError}; #[derive(Deserialize, Serialize, ToSchema)] pub struct PublicKeyContent { - pub(crate) content: String, + pub(crate) content: String, } impl TryFrom for PublicKeyContent { type Error = Error; fn try_from(value: DataKey) -> std::result::Result { - Ok(PublicKeyContent{ + Ok(PublicKeyContent { content: String::from_utf8_lossy(&value.public_key).to_string(), }) } @@ -52,7 +54,7 @@ impl TryFrom for CertificateContent { type Error = Error; fn try_from(value: DataKey) -> std::result::Result { - Ok(CertificateContent{ + Ok(CertificateContent { content: String::from_utf8_lossy(&value.certificate).to_string(), }) } @@ -67,7 +69,7 @@ impl TryFrom for CRLContent { type Error = Error; fn try_from(value: X509CRL) -> std::result::Result { - Ok(CRLContent{ + Ok(CRLContent { content: String::from_utf8_lossy(&value.data).to_string(), }) } @@ -99,7 +101,6 @@ pub struct ListKeyQuery { /// the request page index, starts from 1, max 1000 #[validate(range(min = 1, max = 1000))] pub page_number: u64, - } impl From for DatakeyPaginationQuery { @@ -110,12 +111,11 @@ impl From for DatakeyPaginationQuery { name: value.name, description: value.description, key_type: value.key_type, - visibility: value.visibility + visibility: value.visibility, } } } - #[derive(Debug, Validate, Deserialize, ToSchema)] pub struct CreateDataKeyDTO { /// Key Name, should be identical, length between 4 and 256, not contains any colon symbol. 
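The DTO hunks in this file all lean on the validator crate's derive pattern: declarative length rules plus custom functions referenced by name, as in the reflowed attributes above and the validate_* helpers further down. The following is a minimal, self-contained sketch of that pattern for orientation only; ExampleDTO, validate_not_blank and check_input are made-up names and are not part of this patch.

use validator::{Validate, ValidationError};

// Hypothetical custom check, in the same shape as validate_key_type/validate_utc_time.
fn validate_not_blank(value: &str) -> Result<(), ValidationError> {
    if value.trim().is_empty() {
        return Err(ValidationError::new("value must not be blank"));
    }
    Ok(())
}

#[derive(Debug, Validate)]
struct ExampleDTO {
    // Mirrors the 4..256 name constraint documented on CreateDataKeyDTO above.
    #[validate(length(min = 4, max = 256, message = "invalid name length"))]
    name: String,
    #[validate(custom(function = "validate_not_blank", message = "invalid description"))]
    description: String,
}

// Returns true when every rule passes, as the handlers do before using a DTO.
fn check_input(dto: &ExampleDTO) -> bool {
    dto.validate().is_ok()
}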
@@ -215,7 +215,7 @@ pub struct PagedMetaDTO { #[derive(Debug, Serialize, ToSchema)] pub struct PagedDatakeyDTO { pub data: Vec, - pub meta: PagedMetaDTO + pub meta: PagedMetaDTO, } impl TryFrom for PagedDatakeyDTO { @@ -235,7 +235,6 @@ impl TryFrom for PagedDatakeyDTO { } } - fn validate_utc_time(expire: &str) -> std::result::Result<(), ValidationError> { if expire.parse::>().is_err() { return Err(ValidationError::new("failed to parse time string to utc")); @@ -245,24 +244,15 @@ fn validate_utc_time(expire: &str) -> std::result::Result<(), ValidationError> { fn validate_key_visibility(visibility: &str) -> std::result::Result<(), ValidationError> { match Visibility::from_str(visibility) { - Ok(_) => { - Ok(()) - } - Err(_) => { - Err(ValidationError::new("unsupported key visibility")) - } + Ok(_) => Ok(()), + Err(_) => Err(ValidationError::new("unsupported key visibility")), } } - fn validate_key_type(key_type: &str) -> std::result::Result<(), ValidationError> { match KeyType::from_str(key_type) { - Ok(_) => { - Ok(()) - } - Err(_) => { - Err(ValidationError::new("unsupported key type")) - } + Ok(_) => Ok(()), + Err(_) => Err(ValidationError::new("unsupported key type")), } } @@ -378,7 +368,7 @@ mod tests { use super::*; #[test] fn test_public_key_content_from_datakey() { - let key1 = DataKey{ + let key1 = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -389,9 +379,9 @@ mod tests { parent_id: Some(2), fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: Default::default(), expire_at: Default::default(), key_state: KeyState::Disabled, @@ -406,7 +396,7 @@ mod tests { #[test] fn test_certificate_content_from_datakey() { - let key1 = DataKey{ + let key1 = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -417,9 +407,9 @@ mod tests { parent_id: Some(2), fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: Default::default(), expire_at: Default::default(), key_state: KeyState::Disabled, @@ -434,7 +424,7 @@ mod tests { #[test] fn test_datakey_dto_from_datakey() { - let key1 = DataKey{ + let key1 = DataKey { id: 1, name: "Test Key".to_string(), description: "".to_string(), @@ -445,9 +435,9 @@ mod tests { parent_id: Some(2), fingerprint: "".to_string(), serial_number: None, - private_key: vec![7,8,9,10], - public_key: vec![4,5,6], - certificate: vec![1,2,3], + private_key: vec![7, 8, 9, 10], + public_key: vec![4, 5, 6], + certificate: vec![1, 2, 3], create_at: Default::default(), expire_at: Default::default(), key_state: KeyState::Disabled, @@ -461,10 +451,10 @@ mod tests { #[test] fn test_crl_content_from_crl_model() { - let crl = X509CRL{ + let crl = X509CRL { id: 1, ca_id: 2, - data: vec![1,2,3], + data: vec![1, 2, 3], create_at: Default::default(), update_at: Default::default(), }; @@ -474,7 +464,7 @@ mod tests { #[test] fn test_list_key_query() { - let page_query_invalid1 = ListKeyQuery{ + let page_query_invalid1 = ListKeyQuery { key_type: Some("x509ee".to_string()), visibility: Some("public".to_string()), name: Some("test".to_string()), @@ -483,7 +473,7 @@ mod tests { page_number: 1, }; assert!(page_query_invalid1.validate().is_err()); - let page_query_invalid2 = 
ListKeyQuery{ + let page_query_invalid2 = ListKeyQuery { key_type: Some("x509ee".to_string()), visibility: Some("public".to_string()), name: Some("test".to_string()), @@ -492,7 +482,7 @@ mod tests { page_number: 1, }; assert!(page_query_invalid2.validate().is_err()); - let page_query_invalid3 = ListKeyQuery{ + let page_query_invalid3 = ListKeyQuery { key_type: Some("x509ee".to_string()), visibility: Some("public".to_string()), name: Some("test".to_string()), @@ -501,7 +491,7 @@ mod tests { page_number: 0, }; assert!(page_query_invalid3.validate().is_err()); - let page_query_invalid4 = ListKeyQuery{ + let page_query_invalid4 = ListKeyQuery { key_type: Some("x509ee".to_string()), visibility: Some("public".to_string()), name: Some("test".to_string()), @@ -510,7 +500,7 @@ mod tests { page_number: 1001, }; assert!(page_query_invalid4.validate().is_err()); - let query = ListKeyQuery{ + let query = ListKeyQuery { key_type: Some("x509ee".to_string()), visibility: Some("public".to_string()), name: Some("test".to_string()), @@ -529,7 +519,7 @@ mod tests { #[test] fn test_create_datakey_dto() { - let invalid_name1 = CreateDataKeyDTO{ + let invalid_name1 = CreateDataKeyDTO { name: "Tes".to_string(), description: "".to_string(), visibility: Some("public".to_string()), @@ -539,11 +529,12 @@ mod tests { expire_at: Default::default(), }; assert!(invalid_name1.validate().is_err()); - let invalid_name2 = CreateDataKeyDTO{ + let invalid_name2 = CreateDataKeyDTO { name: "1234567890123456789012345678901234567890123456789012345678901234\ 567890123456789012345678901234567890123456789012345678901234567890123456\ 7890123456789012345678901234567890123456789012345678901234567890123456789\ - 012345678901234567890123456789012345678901234567890".to_string(), + 012345678901234567890123456789012345678901234567890" + .to_string(), description: "".to_string(), visibility: Some("public".to_string()), attributes: HashMap::new(), @@ -552,12 +543,13 @@ mod tests { expire_at: Default::default(), }; assert!(invalid_name2.validate().is_err()); - let invalid_desc1 = CreateDataKeyDTO{ + let invalid_desc1 = CreateDataKeyDTO { name: "Test".to_string(), description: "1234567890123456789012345678901234567890123456789012345678901234\ 567890123456789012345678901234567890123456789012345678901234567890123456\ 7890123456789012345678901234567890123456789012345678901234567890123456789\ - 012345678901234567890123456789012345678901234567890".to_string(), + 012345678901234567890123456789012345678901234567890" + .to_string(), visibility: Some("public".to_string()), attributes: HashMap::new(), key_type: "pgp".to_string(), @@ -565,7 +557,7 @@ mod tests { expire_at: Default::default(), }; assert!(invalid_desc1.validate().is_err()); - let invalid_visibility = CreateDataKeyDTO{ + let invalid_visibility = CreateDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("123".to_string()), @@ -576,7 +568,7 @@ mod tests { }; assert!(invalid_visibility.validate().is_err()); - let invalid_type = CreateDataKeyDTO{ + let invalid_type = CreateDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("public".to_string()), @@ -587,7 +579,7 @@ mod tests { }; assert!(invalid_type.validate().is_err()); - let invalid_expire = CreateDataKeyDTO{ + let invalid_expire = CreateDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("public".to_string()), @@ -598,7 +590,7 @@ mod tests { }; assert!(invalid_expire.validate().is_err()); - let dto = 
CreateDataKeyDTO{ + let dto = CreateDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("public".to_string()), @@ -607,7 +599,7 @@ mod tests { parent_id: Some(2), expire_at: Utc::now().to_string(), }; - let identity = UserIdentity{ + let identity = UserIdentity { email: "email1".to_string(), id: 1, csrf_generation_token: None, @@ -622,7 +614,7 @@ mod tests { #[test] fn test_import_datakey_dto() { - let invalid_name1 = ImportDataKeyDTO{ + let invalid_name1 = ImportDataKeyDTO { name: "Tes".to_string(), description: "".to_string(), visibility: Some("public".to_string()), @@ -633,11 +625,12 @@ mod tests { private_key: "1234".to_string(), }; assert!(invalid_name1.validate().is_err()); - let invalid_name2 = ImportDataKeyDTO{ + let invalid_name2 = ImportDataKeyDTO { name: "1234567890123456789012345678901234567890123456789012345678901234\ 567890123456789012345678901234567890123456789012345678901234567890123456\ 7890123456789012345678901234567890123456789012345678901234567890123456789\ - 012345678901234567890123456789012345678901234567890".to_string(), + 012345678901234567890123456789012345678901234567890" + .to_string(), description: "".to_string(), visibility: Some("public".to_string()), attributes: HashMap::new(), @@ -647,12 +640,13 @@ mod tests { private_key: "1234".to_string(), }; assert!(invalid_name2.validate().is_err()); - let invalid_desc1 = ImportDataKeyDTO{ + let invalid_desc1 = ImportDataKeyDTO { name: "Test".to_string(), description: "1234567890123456789012345678901234567890123456789012345678901234\ 567890123456789012345678901234567890123456789012345678901234567890123456\ 7890123456789012345678901234567890123456789012345678901234567890123456789\ - 012345678901234567890123456789012345678901234567890".to_string(), + 012345678901234567890123456789012345678901234567890" + .to_string(), visibility: Some("public".to_string()), attributes: HashMap::new(), key_type: "pgp".to_string(), @@ -661,7 +655,7 @@ mod tests { private_key: "1234".to_string(), }; assert!(invalid_desc1.validate().is_err()); - let invalid_visibility = ImportDataKeyDTO{ + let invalid_visibility = ImportDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("123".to_string()), @@ -673,7 +667,7 @@ mod tests { }; assert!(invalid_visibility.validate().is_err()); - let invalid_type = ImportDataKeyDTO{ + let invalid_type = ImportDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("public".to_string()), @@ -685,7 +679,7 @@ mod tests { }; assert!(invalid_type.validate().is_err()); - let dto = ImportDataKeyDTO{ + let dto = ImportDataKeyDTO { name: "Test".to_string(), description: "test descr".to_string(), visibility: Some("public".to_string()), @@ -695,7 +689,7 @@ mod tests { public_key: "1234".to_string(), private_key: "1234".to_string(), }; - let identity = UserIdentity{ + let identity = UserIdentity { email: "email1".to_string(), id: 1, csrf_generation_token: None, diff --git a/src/presentation/handler/control/model/datakey/mod.rs b/src/presentation/handler/control/model/datakey/mod.rs index 6c29114..a07dce5 100644 --- a/src/presentation/handler/control/model/datakey/mod.rs +++ b/src/presentation/handler/control/model/datakey/mod.rs @@ -1 +1 @@ -pub mod dto; \ No newline at end of file +pub mod dto; diff --git a/src/presentation/handler/control/model/mod.rs b/src/presentation/handler/control/model/mod.rs index 151f572..ec2aed4 100644 --- a/src/presentation/handler/control/model/mod.rs +++ 
b/src/presentation/handler/control/model/mod.rs @@ -1,3 +1,3 @@ pub mod datakey; +pub mod token; pub mod user; -pub mod token; \ No newline at end of file diff --git a/src/presentation/handler/control/model/token/dto.rs b/src/presentation/handler/control/model/token/dto.rs index e61b76b..97c99a7 100644 --- a/src/presentation/handler/control/model/token/dto.rs +++ b/src/presentation/handler/control/model/token/dto.rs @@ -18,14 +18,13 @@ use serde::{Deserialize, Serialize}; use std::convert::From; use crate::domain::token::entity::Token; -use utoipa::{ToSchema}; +use utoipa::ToSchema; #[derive(Debug, Deserialize, ToSchema)] pub struct CreateTokenDTO { pub description: String, } - #[derive(Debug, Serialize, ToSchema)] pub struct TokenDTO { #[serde(skip_deserializing)] @@ -38,14 +37,12 @@ pub struct TokenDTO { #[serde(skip_deserializing)] pub create_at: String, #[serde(skip_deserializing)] - pub expire_at: String + pub expire_at: String, } impl CreateTokenDTO { pub fn new(description: String) -> CreateTokenDTO { - CreateTokenDTO { - description, - } + CreateTokenDTO { description } } } @@ -60,4 +57,4 @@ impl From for TokenDTO { create_at: token.create_at.to_string(), } } -} \ No newline at end of file +} diff --git a/src/presentation/handler/control/model/token/mod.rs b/src/presentation/handler/control/model/token/mod.rs index 6c29114..a07dce5 100644 --- a/src/presentation/handler/control/model/token/mod.rs +++ b/src/presentation/handler/control/model/token/mod.rs @@ -1 +1 @@ -pub mod dto; \ No newline at end of file +pub mod dto; diff --git a/src/presentation/handler/control/model/user/dto.rs b/src/presentation/handler/control/model/user/dto.rs index 48d2bc9..ecbccbc 100644 --- a/src/presentation/handler/control/model/user/dto.rs +++ b/src/presentation/handler/control/model/user/dto.rs @@ -14,41 +14,43 @@ * */ -use actix_web::{Result, HttpRequest, FromRequest, dev::Payload, dev::ServiceRequest, body::MessageBody,dev::ServiceResponse}; +use crate::application::user::UserService; +use crate::domain::user::entity::User; +use crate::util::error::Error::GeneratingKeyError; use crate::util::error::{Error, Result as SignatrustResult}; -use std::convert::TryInto; +use crate::util::key::generate_csrf_parent_token; use actix_identity::Identity; -use actix_web_lab::middleware::Next; -use actix_web::web; -use std::pin::Pin; -use futures::Future; -use serde::{Deserialize, Serialize}; -use std::convert::From; use actix_web::http::header::HeaderName; -use crate::application::user::UserService; -use crate::domain::user::entity::User; -use utoipa::{IntoParams, ToSchema}; -use validator::Validate; +use actix_web::web; +use actix_web::{ + body::MessageBody, dev::Payload, dev::ServiceRequest, dev::ServiceResponse, FromRequest, + HttpRequest, Result, +}; +use actix_web_lab::middleware::Next; use csrf::{AesGcmCsrfProtection, CsrfProtection}; use data_encoding::BASE64; +use futures::Future; use reqwest::header::HeaderValue; use reqwest::StatusCode; use secstr::SecVec; -use crate::util::error::Error::GeneratingKeyError; -use crate::util::key::generate_csrf_parent_token; +use serde::{Deserialize, Serialize}; +use std::convert::From; +use std::convert::TryInto; +use std::pin::Pin; +use utoipa::{IntoParams, ToSchema}; +use validator::Validate; pub const CSRF_HEADER_NAME: &str = "Xsrf-Token"; pub const AUTH_HEADER_NAME: &str = "Authorization"; pub const SET_COOKIE_HEADER: &str = "set-cookie"; - #[derive(Debug, Deserialize, Serialize, ToSchema, Clone)] pub struct UserIdentity { pub email: String, pub id: i32, //these two 
only exist when calling from OIDC login pub csrf_generation_token: Option>, - pub csrf_token: Option + pub csrf_token: Option, } impl UserIdentity { @@ -57,7 +59,7 @@ impl UserIdentity { id: id.id, email: id.email, csrf_token: None, - csrf_generation_token: None + csrf_generation_token: None, } } @@ -71,55 +73,73 @@ impl UserIdentity { id: id.id, email: id.email, csrf_generation_token: Some(random_token.to_vec()), - csrf_token: Some(token.b64_string()) + csrf_token: Some(token.b64_string()), }) } - pub fn generate_new_csrf_cookie(&self, protect_key: [u8; 32], ttl_seconds: i64) -> SignatrustResult { + pub fn generate_new_csrf_cookie( + &self, + protect_key: [u8; 32], + ttl_seconds: i64, + ) -> SignatrustResult { if self.csrf_generation_token.is_none() || self.csrf_token.is_none() { - return Err(GeneratingKeyError("csrf token is empty, cannot generate new csrf cookie".to_string())); + return Err(GeneratingKeyError( + "csrf token is empty, cannot generate new csrf cookie".to_string(), + )); } let protect = AesGcmCsrfProtection::from_key(protect_key); let generation_token: [u8; 64] = self.csrf_generation_token.clone().unwrap().try_into()?; - let cookie = protect.generate_cookie( - &generation_token, ttl_seconds)?; + let cookie = protect.generate_cookie(&generation_token, ttl_seconds)?; Ok(cookie.b64_string()) } pub fn csrf_cookie_valid(&self, protect_key: [u8; 32], value: &str) -> SignatrustResult { let protect = AesGcmCsrfProtection::from_key(protect_key); Ok(protect.verify_token_pair( - &protect.parse_token( - &BASE64.decode(self.csrf_token.clone().unwrap().as_bytes())?)?, - &protect.parse_cookie( - &BASE64.decode(value.as_bytes())?)?)) + &protect.parse_token(&BASE64.decode(self.csrf_token.clone().unwrap().as_bytes())?)?, + &protect.parse_cookie(&BASE64.decode(value.as_bytes())?)?, + )) } - pub async fn append_csrf_cookie(req: ServiceRequest, next: Next) -> core::result::Result, actix_web::error::Error> { + pub async fn append_csrf_cookie( + req: ServiceRequest, + next: Next, + ) -> core::result::Result, actix_web::error::Error> + { let mut response = next.call(req).await?; - if let Ok(identity) = Identity::from_request(response.request(), &mut Payload::None).into_inner() { - if let Ok(user_json) = identity.id() { - if let Ok(user) = serde_json::from_str::(&user_json) { - if response.status() == StatusCode::UNAUTHORIZED { - //only append csrf token in authorized response - return Ok(response); - } - //generate csrf cookie based on user token - if let Some(protect_key) = response.request().app_data::>>() { - if let Ok(protect_key_array) = protect_key.clone().unsecure().try_into() { - if let Ok(csrf_token) = user.generate_new_csrf_cookie(protect_key_array, 600) { - let http_header = response.headers_mut(); - http_header.insert( - HeaderName::from_static(SET_COOKIE_HEADER), - HeaderValue::from_str(&format!("{}={}; Secure; Path=/; Max-Age=600", CSRF_HEADER_NAME, csrf_token)).unwrap(), - ); - } else { - warn!("failed to generate csrf token in middleware"); - } - } - } - } - } + if let Ok(identity) = + Identity::from_request(response.request(), &mut Payload::None).into_inner() + { + if let Ok(user_json) = identity.id() { + if let Ok(user) = serde_json::from_str::(&user_json) { + if response.status() == StatusCode::UNAUTHORIZED { + //only append csrf token in authorized response + return Ok(response); + } + //generate csrf cookie based on user token + if let Some(protect_key) = + response.request().app_data::>>() + { + if let Ok(protect_key_array) = protect_key.clone().unsecure().try_into() { + if 
let Ok(csrf_token) = + user.generate_new_csrf_cookie(protect_key_array, 600) + { + let http_header = response.headers_mut(); + http_header.insert( + HeaderName::from_static(SET_COOKIE_HEADER), + HeaderValue::from_str(&format!( + "{}={}; Secure; Path=/; Max-Age=600", + CSRF_HEADER_NAME, csrf_token + )) + .unwrap(), + ); + } else { + warn!("failed to generate csrf token in middleware"); + } + } + } + } + } } Ok(response) @@ -130,7 +150,7 @@ impl From for User { fn from(id: UserIdentity) -> Self { User { id: id.id, - email: id.email + email: id.email, } } } @@ -158,7 +178,11 @@ impl FromRequest for UserIdentity { None => { if let Some(value) = req.headers().get(AUTH_HEADER_NAME) { if let Some(user_service) = req.app_data::>() { - if let Ok(user) = user_service.get_ref().validate_token(value.to_str().unwrap()).await { + if let Ok(user) = user_service + .get_ref() + .validate_token(value.to_str().unwrap()) + .await + { return Ok(UserIdentity::from_user(user)); } else { warn!("unable to find token record"); @@ -173,9 +197,12 @@ impl FromRequest for UserIdentity { Some(user) => { if let Some(protect_key) = req.app_data::>>() { if let Some(header) = req.headers().get(CSRF_HEADER_NAME) { - if let Ok(protect_key_array) = protect_key.clone().unsecure().try_into() { - if let Ok(true) = user.csrf_cookie_valid(protect_key_array, header.to_str().unwrap()) { - return Ok(user) + if let Ok(protect_key_array) = protect_key.clone().unsecure().try_into() + { + if let Ok(true) = user + .csrf_cookie_valid(protect_key_array, header.to_str().unwrap()) + { + return Ok(user); } else { warn!("csrf header is invalid"); } @@ -198,4 +225,3 @@ pub struct Code { #[validate(length(min = 1))] pub code: String, } - diff --git a/src/presentation/handler/control/model/user/mod.rs b/src/presentation/handler/control/model/user/mod.rs index 6c29114..a07dce5 100644 --- a/src/presentation/handler/control/model/user/mod.rs +++ b/src/presentation/handler/control/model/user/mod.rs @@ -1 +1 @@ -pub mod dto; \ No newline at end of file +pub mod dto; diff --git a/src/presentation/handler/control/user_handler.rs b/src/presentation/handler/control/user_handler.rs index 8ca7d63..833576e 100644 --- a/src/presentation/handler/control/user_handler.rs +++ b/src/presentation/handler/control/user_handler.rs @@ -14,11 +14,11 @@ * */ -use actix_web::{HttpResponse, Responder, Result, web, Scope, HttpRequest, HttpMessage}; -use crate::util::error::Error; use super::model::user::dto::UserIdentity; +use crate::util::error::Error; use actix_identity::Identity; use actix_web::cookie::Cookie; +use actix_web::{web, HttpMessage, HttpRequest, HttpResponse, Responder, Result, Scope}; use secstr::SecVec; use validator::Validate; @@ -42,7 +42,12 @@ use crate::presentation::handler::control::model::user::dto::{Code, CSRF_HEADER_ ) )] async fn login(user_service: web::Data) -> Result { - Ok(HttpResponse::Found().insert_header(("Location", user_service.into_inner().get_login_url().await?.as_str())).finish()) + Ok(HttpResponse::Found() + .insert_header(( + "Location", + user_service.into_inner().get_login_url().await?.as_str(), + )) + .finish()) } /// Get login user information @@ -84,7 +89,7 @@ async fn info(id: UserIdentity) -> Result { )] async fn logout(id: Identity) -> Result { id.logout(); - Ok( HttpResponse::NoContent().finish()) + Ok(HttpResponse::NoContent().finish()) } /// Callback API for OIDC provider @@ -105,24 +110,39 @@ async fn logout(id: Identity) -> Result { (status = 500, description = "Server internal error", body = ErrorMessage) ) )] -async fn 
callback(req: HttpRequest, user_service: web::Data, code: web::Query, protect_key: web::Data>) -> Result { +async fn callback( + req: HttpRequest, + user_service: web::Data, + code: web::Query, + protect_key: web::Data>, +) -> Result { code.validate()?; //generate csrf token and cookie - let protect_key_array:[u8; 32] = protect_key.unsecure().try_into()?; - let user_entity = UserIdentity::from_user_with_csrf_token(user_service.validate_user(&code.code).await?, protect_key_array)?; + let protect_key_array: [u8; 32] = protect_key.unsecure().try_into()?; + let user_entity = UserIdentity::from_user_with_csrf_token( + user_service.validate_user(&code.code).await?, + protect_key_array, + )?; match Identity::login(&req.extensions(), serde_json::to_string(&user_entity)?) { Ok(_) => { - let cookie = Cookie::build(CSRF_HEADER_NAME, user_entity.generate_new_csrf_cookie(protect_key_array, 3600)?) - .path("/") - .secure(true) - .same_site(actix_web::cookie::SameSite::Strict) - .expires(time::OffsetDateTime::now_utc() + time::Duration::seconds(3600)) - .finish(); - Ok(HttpResponse::Found().cookie(cookie).insert_header(("Location", "/")).finish()) - } - Err(err) => { - Err(Error::AuthError(format!("failed to get oidc token {}", err))) + let cookie = Cookie::build( + CSRF_HEADER_NAME, + user_entity.generate_new_csrf_cookie(protect_key_array, 3600)?, + ) + .path("/") + .secure(true) + .same_site(actix_web::cookie::SameSite::Strict) + .expires(time::OffsetDateTime::now_utc() + time::Duration::seconds(3600)) + .finish(); + Ok(HttpResponse::Found() + .cookie(cookie) + .insert_header(("Location", "/")) + .finish()) } + Err(err) => Err(Error::AuthError(format!( + "failed to get oidc token {}", + err + ))), } } @@ -145,8 +165,15 @@ async fn callback(req: HttpRequest, user_service: web::Data, co (status = 500, description = "Server internal error", body = ErrorMessage) ) )] -async fn new_token(user: UserIdentity, user_service: web::Data, token: web::Json) -> Result { - let token = user_service.into_inner().generate_token(&user, token.0).await?; +async fn new_token( + user: UserIdentity, + user_service: web::Data, + token: web::Json, +) -> Result { + let token = user_service + .into_inner() + .generate_token(&user, token.0) + .await?; Ok(HttpResponse::Created().json(TokenDTO::from(token))) } @@ -174,8 +201,15 @@ async fn new_token(user: UserIdentity, user_service: web::Data, (status = 500, description = "Server internal error", body = ErrorMessage) ) )] -async fn delete_token(user: UserIdentity, user_service: web::Data, id: web::Path) -> Result { - user_service.into_inner().delete_token(&user, id.parse::()?).await?; +async fn delete_token( + user: UserIdentity, + user_service: web::Data, + id: web::Path, +) -> Result { + user_service + .into_inner() + .delete_token(&user, id.parse::()?) 
+ .await?; Ok(HttpResponse::Ok()) } @@ -198,7 +232,10 @@ async fn delete_token(user: UserIdentity, user_service: web::Data) -> Result { +async fn list_token( + user: UserIdentity, + user_service: web::Data, +) -> Result { let token = user_service.into_inner().get_token(&user).await?; let mut results = vec![]; for t in token.into_iter() { @@ -213,9 +250,10 @@ pub fn get_scope() -> Scope { .service(web::resource("/login").route(web::get().to(login))) .service(web::resource("/logout").route(web::post().to(logout))) .service(web::resource("/callback").route(web::get().to(callback))) - .service(web::resource("/api_keys") - .route(web::post().to(new_token)) - .route(web::get().to(list_token))) - .service( web::resource("/api_keys/{id}") - .route(web::delete().to(delete_token))) -} \ No newline at end of file + .service( + web::resource("/api_keys") + .route(web::post().to(new_token)) + .route(web::get().to(list_token)), + ) + .service(web::resource("/api_keys/{id}").route(web::delete().to(delete_token))) +} diff --git a/src/presentation/handler/data/health_handler.rs b/src/presentation/handler/data/health_handler.rs index 9239360..2710d3c 100644 --- a/src/presentation/handler/data/health_handler.rs +++ b/src/presentation/handler/data/health_handler.rs @@ -16,10 +16,11 @@ use std::pin::Pin; pub mod health { tonic::include_proto!("grpc.health.v1"); } -use tokio_stream::{Stream, once}; +use tokio_stream::{once, Stream}; use health::{ - health_server::Health, health_server::HealthServer, HealthCheckRequest, HealthCheckResponse, health_check_response::ServingStatus, + health_check_response::ServingStatus, health_server::Health, health_server::HealthServer, + HealthCheckRequest, HealthCheckResponse, }; use tonic::{Request, Response, Status}; @@ -34,8 +35,7 @@ impl HealthHandler { } #[tonic::async_trait] -impl Health for HealthHandler -{ +impl Health for HealthHandler { type WatchStream = ResponseStream; async fn check( &self, @@ -57,11 +57,9 @@ impl Health for HealthHandler let reply_stream = once(Ok(reply)); Ok(Response::new(Box::pin(reply_stream))) } - } -pub fn get_grpc_handler() -> HealthServer -{ +pub fn get_grpc_handler() -> HealthServer { let app = HealthHandler::new(); HealthServer::new(app) } diff --git a/src/presentation/handler/data/mod.rs b/src/presentation/handler/data/mod.rs index 25588d1..c357f4a 100644 --- a/src/presentation/handler/data/mod.rs +++ b/src/presentation/handler/data/mod.rs @@ -1,2 +1,2 @@ +pub mod health_handler; pub mod sign_handler; -pub mod health_handler; \ No newline at end of file diff --git a/src/presentation/handler/data/sign_handler.rs b/src/presentation/handler/data/sign_handler.rs index f870327..7d2295c 100644 --- a/src/presentation/handler/data/sign_handler.rs +++ b/src/presentation/handler/data/sign_handler.rs @@ -21,15 +21,15 @@ pub mod signatrust { } use tokio_stream::StreamExt; -use signatrust::{ - signatrust_server::Signatrust, signatrust_server::SignatrustServer, SignStreamRequest, - SignStreamResponse, GetKeyInfoRequest, GetKeyInfoResponse -}; -use tonic::{Request, Response, Status, Streaming}; use crate::application::datakey::KeyService; use crate::application::user::UserService; use crate::util::error::Error; use crate::util::error::Result as SignatrustResult; +use signatrust::{ + signatrust_server::Signatrust, signatrust_server::SignatrustServer, GetKeyInfoRequest, + GetKeyInfoResponse, SignStreamRequest, SignStreamResponse, +}; +use tonic::{Request, Response, Status, Streaming}; pub struct SignHandler where @@ -48,16 +48,27 @@ where pub fn 
new(key_service: K, user_service: U) -> Self { SignHandler { key_service, - user_service + user_service, } } - async fn validate_key_token_matched(&self, token: Option, name: &str) -> SignatrustResult<()> { + async fn validate_key_token_matched( + &self, + token: Option, + name: &str, + ) -> SignatrustResult<()> { let names: Vec<_> = name.split(':').collect(); if names.len() <= 1 { - return Ok(()) + return Ok(()); } - if token.is_none() || !self.user_service.validate_token_and_email(names[0], &token.unwrap()).await? { - return Err(Error::AuthError("user token and email unmatched".to_string())) + if token.is_none() + || !self + .user_service + .validate_token_and_email(names[0], &token.unwrap()) + .await? + { + return Err(Error::AuthError( + "user token and email unmatched".to_string(), + )); } Ok(()) } @@ -72,30 +83,32 @@ where async fn get_key_info( &self, request: Request, - ) -> Result, Status> - { + ) -> Result, Status> { let request = request.into_inner(); //perform token validation on private keys - if let Err(err) = self.validate_key_token_matched(request.token, &request.key_id).await { + if let Err(err) = self + .validate_key_token_matched(request.token, &request.key_id) + .await + { return Ok(Response::new(GetKeyInfoResponse { attributes: HashMap::new(), error: err.to_string(), - })) - } - return match self.key_service.get_by_type_and_name(request.key_type, request.key_id).await { - Ok(datakey) => { - Ok(Response::new(GetKeyInfoResponse { - attributes: datakey.attributes, - error: "".to_string(), - })) - } - Err(err) => { - Ok(Response::new(GetKeyInfoResponse { - attributes: HashMap::new(), - error: err.to_string(), - })) - } + })); } + return match self + .key_service + .get_by_type_and_name(request.key_type, request.key_id) + .await + { + Ok(datakey) => Ok(Response::new(GetKeyInfoResponse { + attributes: datakey.attributes, + error: "".to_string(), + })), + Err(err) => Ok(Response::new(GetKeyInfoResponse { + attributes: HashMap::new(), + error: err.to_string(), + })), + }; } async fn sign_stream( &self, @@ -120,30 +133,36 @@ where return Ok(Response::new(SignStreamResponse { signature: vec![], error: err.to_string(), - })) + })); } - debug!("begin to sign key_type :{} key_name: {}", key_type, key_name); - match self.key_service.sign(key_type, key_name, &options, data).await { - Ok(content) => { - Ok(Response::new(SignStreamResponse { - signature: content, - error: "".to_string() - })) - } - Err(err) => { - Ok(Response::new(SignStreamResponse { - signature: vec![], - error: err.to_string(), - })) - } + debug!( + "begin to sign key_type :{} key_name: {}", + key_type, key_name + ); + match self + .key_service + .sign(key_type, key_name, &options, data) + .await + { + Ok(content) => Ok(Response::new(SignStreamResponse { + signature: content, + error: "".to_string(), + })), + Err(err) => Ok(Response::new(SignStreamResponse { + signature: vec![], + error: err.to_string(), + })), } } } -pub fn get_grpc_handler(key_service: K, user_service: U) -> SignatrustServer> +pub fn get_grpc_handler( + key_service: K, + user_service: U, +) -> SignatrustServer> where K: KeyService + 'static, - U: UserService + 'static + U: UserService + 'static, { let app = SignHandler::new(key_service, user_service); SignatrustServer::new(app) diff --git a/src/presentation/mod.rs b/src/presentation/mod.rs index 25c2756..918a0f0 100644 --- a/src/presentation/mod.rs +++ b/src/presentation/mod.rs @@ -1,2 +1,2 @@ -pub mod server; pub mod handler; +pub mod server; diff --git a/src/presentation/server/control_server.rs 
b/src/presentation/server/control_server.rs index b4a4940..6311b62 100644 --- a/src/presentation/server/control_server.rs +++ b/src/presentation/server/control_server.rs @@ -14,41 +14,41 @@ * */ +use actix_identity::IdentityMiddleware; +use actix_limitation::{Limiter, RateLimiter}; +use actix_session::{config::PersistentSession, storage::RedisSessionStore, SessionMiddleware}; +use actix_web::{cookie::Key, middleware, web, App, HttpServer}; +use config::Config; +use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; use std::net::SocketAddr; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use time::Duration as timeDuration; use utoipa::{ openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}, Modify, OpenApi, }; use utoipa_swagger_ui::SwaggerUi; -use std::sync::{Arc, RwLock}; -use actix_web::{App, HttpServer, middleware, web, cookie::Key}; -use config::{Config}; -use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -use actix_identity::{IdentityMiddleware}; -use actix_session::{config::PersistentSession, storage::RedisSessionStore, SessionMiddleware}; -use actix_limitation::{Limiter, RateLimiter}; -use time::Duration as timeDuration; -use std::time::Duration; use crate::infra::database::model::datakey::repository as datakeyRepository; use crate::infra::database::pool::{create_pool, get_db_connection}; use crate::presentation::handler::control::*; -use actix_web::{dev::ServiceRequest}; +use crate::util::error::{Error, Result}; use actix_web::cookie::SameSite; +use actix_web::dev::ServiceRequest; use secstr::SecVec; use tokio_util::sync::CancellationToken; -use crate::util::error::{Error, Result}; use crate::application::datakey::{DBKeyService, KeyService}; -use crate::infra::database::model::token::repository::TokenRepository; -use crate::infra::database::model::user::repository::UserRepository; -use crate::infra::sign_backend::factory::SignBackendFactory; use crate::application::user::{DBUserService, UserService}; use crate::domain::datakey::entity::{DataKey, KeyState}; use crate::domain::token::entity::Token; use crate::domain::user::entity::User; -use crate::presentation::handler::control::model::token::dto::{CreateTokenDTO}; -use crate::presentation::handler::control::model::user::dto::{UserIdentity}; +use crate::infra::database::model::token::repository::TokenRepository; +use crate::infra::database::model::user::repository::UserRepository; +use crate::infra::sign_backend::factory::SignBackendFactory; +use crate::presentation::handler::control::model::token::dto::CreateTokenDTO; +use crate::presentation::handler::control::model::user::dto::UserIdentity; use crate::util::key::{file_exists, truncate_string_to_protect_key}; pub struct ControlServer { @@ -118,26 +118,27 @@ impl Modify for SecurityAddon { struct ControlApiDoc; impl ControlServer { - pub async fn new(server_config: Arc>, cancel_token: CancellationToken) -> Result { + pub async fn new( + server_config: Arc>, + cancel_token: CancellationToken, + ) -> Result { let database = server_config.read()?.get_table("database")?; create_pool(&database).await?; - let data_repository = datakeyRepository::DataKeyRepository::new( - get_db_connection()?, - ); - let sign_backend = SignBackendFactory::new_engine( - server_config.clone(), get_db_connection()?).await?; + let data_repository = datakeyRepository::DataKeyRepository::new(get_db_connection()?); + let sign_backend = + SignBackendFactory::new_engine(server_config.clone(), get_db_connection()?).await?; //initialize repos let user_repo = 
UserRepository::new(get_db_connection()?); let token_repo = TokenRepository::new(get_db_connection()?); //initialize the service - let user_service = Arc::new( - DBUserService::new( - user_repo, token_repo, - server_config.clone())?) as Arc; - let key_service = Arc::new( - DBKeyService::new( - data_repository, sign_backend)) as Arc; + let user_service = Arc::new(DBUserService::new( + user_repo, + token_repo, + server_config.clone(), + )?) as Arc; + let key_service = + Arc::new(DBKeyService::new(data_repository, sign_backend)) as Arc; let server = ControlServer { user_service, key_service, @@ -158,22 +159,30 @@ impl ControlServer { .read()? .get_string("control-server.server_port")? ) - .parse()?; + .parse()?; - let key = self.server_config.read()?.get_string("control-server.cookie_key")?; - let redis_connection = self.server_config.read()?.get_string("control-server.redis_connection")?; + let key = self + .server_config + .read()? + .get_string("control-server.cookie_key")?; + let redis_connection = self + .server_config + .read()? + .get_string("control-server.redis_connection")?; info!("control server starts"); // Start http server - let user_service = web::Data::from( - self.user_service.clone()); - let key_service = web::Data::from( - self.key_service.clone()); + let user_service = web::Data::from(self.user_service.clone()); + let key_service = web::Data::from(self.key_service.clone()); key_service.start_key_rotate_loop(self.cancel_token.clone())?; key_service.start_key_plugin_maintenance( self.cancel_token.clone(), - self.server_config.read()?.get_string("control-server.server_port")?.parse()?)?; + self.server_config + .read()? + .get_string("control-server.server_port")? + .parse()?, + )?; //prepare redis store let store = RedisSessionStore::new(&redis_connection).await?; @@ -188,14 +197,20 @@ impl ControlServer { } None }) - .limit(self.server_config.read()?.get_string("control-server.limits_per_minute")?.parse()?) + .limit( + self.server_config + .read()? + .get_string("control-server.limits_per_minute")? 
+ .parse()?, + ) .period(Duration::from_secs(60)) .build() .unwrap(), ); let openapi = ControlApiDoc::openapi(); - let csrf_protect_key = web::Data::new(SecVec::new(truncate_string_to_protect_key(&key).to_vec())); + let csrf_protect_key = + web::Data::new(SecVec::new(truncate_string_to_protect_key(&key).to_vec())); let http_server = HttpServer::new(move || { App::new() @@ -211,9 +226,10 @@ impl ControlServer { .wrap(RateLimiter::default()) // session handler .wrap( - SessionMiddleware::builder( - store.clone(), Key::from(key.as_bytes())) - .session_lifecycle(PersistentSession::default().session_ttl(timeDuration::hours(1))) + SessionMiddleware::builder(store.clone(), Key::from(key.as_bytes())) + .session_lifecycle( + PersistentSession::default().session_ttl(timeDuration::hours(1)), + ) .cookie_name("Signatrust".to_owned()) .cookie_secure(true) .cookie_domain(None) @@ -227,25 +243,42 @@ impl ControlServer { .app_data(limiter.clone()) //open api document .service( - SwaggerUi::new("/api/swagger-ui/{_:.*}").url("/api-doc/openapi.json", openapi.clone()), + SwaggerUi::new("/api/swagger-ui/{_:.*}") + .url("/api-doc/openapi.json", openapi.clone()), + ) + .service( + web::scope("/api/v1") + .service(user_handler::get_scope()) + .service(datakey_handler::get_scope()), ) - .service(web::scope("/api/v1") - .service(user_handler::get_scope()) - .service(datakey_handler::get_scope())) - .service(web::scope("/api") - .service(health_handler::get_scope())) + .service(web::scope("/api").service(health_handler::get_scope())) }); - let tls_cert = self.server_config.read()?.get_string("tls_cert").unwrap_or(String::new()).to_string(); - let tls_key = self.server_config.read()?.get_string("tls_key").unwrap_or(String::new()).to_string(); + let tls_cert = self + .server_config + .read()? + .get_string("tls_cert") + .unwrap_or(String::new()) + .to_string(); + let tls_key = self + .server_config + .read()? 
+ .get_string("tls_key") + .unwrap_or(String::new()) + .to_string(); if tls_cert.is_empty() || tls_key.is_empty() { info!("tls key and cert not configured, control server tls will be disabled"); http_server.bind(addr)?.run().await?; } else { if !file_exists(&tls_cert) || !file_exists(&tls_key) { - return Err(Error::FileFoundError(format!("tls cert: {} or key: {} file not found", tls_key, tls_cert))); + return Err(Error::FileFoundError(format!( + "tls cert: {} or key: {} file not found", + tls_key, tls_cert + ))); } let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_private_key_file(tls_key, SslFiletype::PEM).unwrap(); + builder + .set_private_key_file(tls_key, SslFiletype::PEM) + .unwrap(); builder.set_certificate_chain_file(tls_cert).unwrap(); http_server.bind_openssl(addr, builder)?.run().await?; } @@ -255,16 +288,21 @@ impl ControlServer { //used for control admin cmd pub async fn create_user_token(&self, user: User) -> Result { let user = self.user_service.save(user).await?; - self.user_service.generate_token( - &UserIdentity::from_user(user.clone()), - CreateTokenDTO::new("default admin token".to_owned())).await + self.user_service + .generate_token( + &UserIdentity::from_user(user.clone()), + CreateTokenDTO::new("default admin token".to_owned()), + ) + .await } //used for control admin cmd pub async fn create_keys(&self, data: &mut DataKey, user: UserIdentity) -> Result { let key = self.key_service.create(user.clone(), data).await?; if data.key_state == KeyState::Disabled { - self.key_service.enable(Some(user), format!("{}", key.id)).await?; + self.key_service + .enable(Some(user), format!("{}", key.id)) + .await?; } Ok(key) } diff --git a/src/presentation/server/data_server.rs b/src/presentation/server/data_server.rs index eb49c48..0c7dfed 100644 --- a/src/presentation/server/data_server.rs +++ b/src/presentation/server/data_server.rs @@ -14,19 +14,14 @@ * */ +use crate::application::datakey::DBKeyService; +use crate::application::user::DBUserService; +use config::Config; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; -use config::{Config}; use tokio::fs; use tokio_util::sync::CancellationToken; -use tonic::{ - transport::{ - Certificate, - Identity, Server, ServerTlsConfig, - }, -}; -use crate::application::datakey::{DBKeyService}; -use crate::application::user::{DBUserService}; +use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig}; use crate::infra::database::model::datakey::repository; use crate::infra::database::model::token::repository::TokenRepository; @@ -34,14 +29,12 @@ use crate::infra::database::model::user::repository::UserRepository; use crate::infra::database::pool::{create_pool, get_db_connection}; use crate::infra::sign_backend::factory::SignBackendFactory; - -use crate::presentation::handler::data::sign_handler::get_grpc_handler as sign_grpc_handler; use crate::presentation::handler::data::health_handler::get_grpc_handler as health_grpc_handler; +use crate::presentation::handler::data::sign_handler::get_grpc_handler as sign_grpc_handler; use crate::util::error::{Error, Result}; use crate::util::key::file_exists; -pub struct DataServer -{ +pub struct DataServer { server_config: Arc>, cancel_token: CancellationToken, server_identity: Option, @@ -49,7 +42,10 @@ pub struct DataServer } impl DataServer { - pub async fn new(server_config: Arc>, cancel_token: CancellationToken) -> Result { + pub async fn new( + server_config: Arc>, + cancel_token: CancellationToken, + ) -> Result { let database = 
server_config.read()?.get_table("database")?; create_pool(&database).await?; let mut server = DataServer { @@ -63,20 +59,39 @@ impl DataServer { } async fn load(&mut self) -> Result<()> { - let ca_root = self.server_config.read()?.get("ca_root").unwrap_or(String::new()).to_string(); - let tls_cert = self.server_config.read()?.get("tls_cert").unwrap_or(String::new()).to_string(); - let tls_key = self.server_config.read()?.get("tls_key").unwrap_or(String::new()).to_string(); + let ca_root = self + .server_config + .read()? + .get("ca_root") + .unwrap_or(String::new()) + .to_string(); + let tls_cert = self + .server_config + .read()? + .get("tls_cert") + .unwrap_or(String::new()) + .to_string(); + let tls_key = self + .server_config + .read()? + .get("tls_key") + .unwrap_or(String::new()) + .to_string(); if ca_root.is_empty() || tls_cert.is_empty() || tls_key.is_empty() { info!("tls key and cert not configured, data server tls will be disabled"); return Ok(()); } if !file_exists(&ca_root) || !file_exists(&tls_cert) || !file_exists(&tls_key) { - return Err(Error::FileFoundError(format!("ca root: {} or tls cert: {} or key: {} file not found", ca_root, tls_key, tls_cert))); + return Err(Error::FileFoundError(format!( + "ca root: {} or tls cert: {} or key: {} file not found", + ca_root, tls_key, tls_cert + ))); } - self.ca_cert = Some( - Certificate::from_pem(fs::read(ca_root).await?)); - self.server_identity = Some(Identity::from_pem(fs::read(tls_cert).await?, - fs::read(tls_key).await?)); + self.ca_cert = Some(Certificate::from_pem(fs::read(ca_root).await?)); + self.server_identity = Some(Identity::from_pem( + fs::read(tls_cert).await?, + fs::read(tls_key).await?, + )); Ok(()) } @@ -106,10 +121,10 @@ impl DataServer { let mut server = Server::builder(); info!("data server starts"); - let sign_backend = SignBackendFactory::new_engine( - self.server_config.clone(), get_db_connection()?).await?; - let data_repository = repository::DataKeyRepository::new( - get_db_connection()?); + let sign_backend = + SignBackendFactory::new_engine(self.server_config.clone(), get_db_connection()?) + .await?; + let data_repository = repository::DataKeyRepository::new(get_db_connection()?); let key_service = DBKeyService::new(data_repository, sign_backend); let user_repo = UserRepository::new(get_db_connection()?); let token_repo = TokenRepository::new(get_db_connection()?); @@ -117,7 +132,11 @@ impl DataServer { if let Some(identity) = self.server_identity.clone() { server - .tls_config(ServerTlsConfig::new().identity(identity).client_ca_root(self.ca_cert.clone().unwrap()))? + .tls_config( + ServerTlsConfig::new() + .identity(identity) + .client_ca_root(self.ca_cert.clone().unwrap()), + )? 
.add_service(sign_grpc_handler(key_service, user_service)) .add_service(health_grpc_handler()) .serve_with_shutdown(addr, self.shutdown_signal()) diff --git a/src/util/cache.rs b/src/util/cache.rs index 3d8695c..ace9301 100644 --- a/src/util/cache.rs +++ b/src/util/cache.rs @@ -14,50 +14,51 @@ * */ -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; -use crate::util::error::{Error, Result}; use crate::domain::datakey::entity::DataKey; use crate::domain::user::entity::User; +use crate::util::error::{Error, Result}; use chrono::{DateTime, Duration, Utc}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; const DATAKEY_EXPIRE_SECOND: i64 = 10 * 60; const USER_EXPIRE_SECOND: i64 = 60 * 60; - - #[derive(Clone)] -pub struct TimedFixedSizeCache -{ +pub struct TimedFixedSizeCache { cached_keys: Arc>>, cached_users: Arc>>, key_size: Option, user_size: Option, key_expire: Option, - user_expire: Option + user_expire: Option, } #[derive(Clone)] pub struct CachedDatakey { time: DateTime, - key: DataKey + key: DataKey, } #[derive(Clone)] pub struct CachedUser { time: DateTime, - user: User + user: User, } -impl TimedFixedSizeCache -{ - pub fn new(key_size: Option, user_size: Option, key_expire: Option, user_expire: Option) -> Self { +impl TimedFixedSizeCache { + pub fn new( + key_size: Option, + user_size: Option, + key_expire: Option, + user_expire: Option, + ) -> Self { Self { cached_keys: Arc::new(RwLock::new(HashMap::new())), cached_users: Arc::new(RwLock::new(HashMap::new())), key_size, user_size, key_expire: key_expire.or(Some(DATAKEY_EXPIRE_SECOND)), - user_expire: user_expire.or(Some(USER_EXPIRE_SECOND)) + user_expire: user_expire.or(Some(USER_EXPIRE_SECOND)), } } @@ -78,9 +79,17 @@ impl TimedFixedSizeCache self.cached_users.write().await.clear(); } } else { - return Err(Error::UnsupportedTypeError("user cache not enabled".to_string())) + return Err(Error::UnsupportedTypeError( + "user cache not enabled".to_string(), + )); } - self.cached_users.write().await.insert(identity.to_owned(), CachedUser{time: Utc::now(), user} ); + self.cached_users.write().await.insert( + identity.to_owned(), + CachedUser { + time: Utc::now(), + user, + }, + ); Ok(()) } @@ -101,12 +110,17 @@ impl TimedFixedSizeCache self.cached_keys.write().await.clear(); } } else { - return Err(Error::UnsupportedTypeError("datakey cache not enabled".to_string())) + return Err(Error::UnsupportedTypeError( + "datakey cache not enabled".to_string(), + )); } - self.cached_keys.write().await.insert(identity.to_owned(), CachedDatakey{ - time: Utc::now(), - key: datakey, - } ); + self.cached_keys.write().await.insert( + identity.to_owned(), + CachedDatakey { + time: Utc::now(), + key: datakey, + }, + ); Ok(()) } @@ -119,11 +133,13 @@ impl TimedFixedSizeCache } pub async fn update_sign_datakey(&self, id_or_name: &str, datakey: DataKey) -> Result<()> { - self.update_datakey(&self.get_sign_identity(id_or_name), datakey).await + self.update_datakey(&self.get_sign_identity(id_or_name), datakey) + .await } pub async fn update_read_datakey(&self, id_or_name: &str, datakey: DataKey) -> Result<()> { - self.update_datakey(&self.get_read_identity(id_or_name), datakey).await + self.update_datakey(&self.get_read_identity(id_or_name), datakey) + .await } fn get_sign_identity(&self, key_name: &str) -> String { @@ -137,19 +153,19 @@ impl TimedFixedSizeCache #[cfg(test)] mod test { - use std::thread::sleep; - use crate::domain::datakey::entity::KeyType::OpenPGP; use super::*; + use 
crate::domain::datakey::entity::KeyType::OpenPGP; use crate::util::error::Result; + use std::thread::sleep; #[tokio::test] - async fn test_user_cache() ->Result<()> { - let user_cache = TimedFixedSizeCache::new(None,Some(1), None, None); - let user1 = User{ + async fn test_user_cache() -> Result<()> { + let user_cache = TimedFixedSizeCache::new(None, Some(1), None, None); + let user1 = User { id: 1, email: "fake_email@gmail.com".to_string(), }; - let user2 = User{ + let user2 = User { id: 2, email: "fake_email@gmail.com".to_string(), }; @@ -163,11 +179,14 @@ mod test { assert_eq!(user_cache.get_user(&identity1).await, None); assert_eq!(user_cache.get_user(&identity2).await, Some(user2.clone())); - let user_cache = TimedFixedSizeCache::new(None,None, None, None); + let user_cache = TimedFixedSizeCache::new(None, None, None, None); assert_eq!(user_cache.get_user(&identity1).await, None); - assert!(user_cache.update_user(&identity2, user2.clone()).await.is_err()); + assert!(user_cache + .update_user(&identity2, user2.clone()) + .await + .is_err()); - let user_cache = TimedFixedSizeCache::new(None,Some(1), None, Some(1)); + let user_cache = TimedFixedSizeCache::new(None, Some(1), None, Some(1)); assert_eq!(user_cache.update_user(&identity2, user2).await?, ()); sleep(Duration::seconds(2).to_std()?); assert_eq!(user_cache.get_user(&identity2).await, None); @@ -175,9 +194,9 @@ mod test { } #[tokio::test] - async fn test_datakey_cache() ->Result<()> { - let key_cache = TimedFixedSizeCache::new(Some(2),None, None, None); - let datakey1 = DataKey{ + async fn test_datakey_cache() -> Result<()> { + let key_cache = TimedFixedSizeCache::new(Some(2), None, None, None); + let datakey1 = DataKey { id: 1, name: "fake datakey1".to_string(), visibility: Default::default(), @@ -199,7 +218,7 @@ mod test { request_revoke_users: None, parent_key: None, }; - let datakey2 = DataKey{ + let datakey2 = DataKey { id: 2, name: "fake datakey2".to_string(), visibility: Default::default(), @@ -224,27 +243,52 @@ mod test { let identity1 = "datakey1"; let identity2 = "1"; assert_eq!(key_cache.get_read_datakey(&identity1).await, None); - assert_eq!(key_cache.update_read_datakey(&identity1, datakey1.clone()).await?, ()); - assert_eq!(key_cache.get_read_datakey(&identity1).await, Some(datakey1.clone())); + assert_eq!( + key_cache + .update_read_datakey(&identity1, datakey1.clone()) + .await?, + () + ); + assert_eq!( + key_cache.get_read_datakey(&identity1).await, + Some(datakey1.clone()) + ); assert_eq!(key_cache.get_sign_datakey(&identity1).await, None); assert_eq!(key_cache.get_datakey(&identity1).await, None); - assert_eq!(key_cache.update_sign_datakey(&identity1, datakey1.clone()).await?, ()); - assert_eq!(key_cache.get_sign_datakey(&identity1).await, Some(datakey1.clone())); - assert_eq!(key_cache.update_sign_datakey(&identity2, datakey2.clone()).await?, ()); - assert_eq!(key_cache.get_sign_datakey(&identity2).await, Some(datakey2.clone())); + assert_eq!( + key_cache + .update_sign_datakey(&identity1, datakey1.clone()) + .await?, + () + ); + assert_eq!( + key_cache.get_sign_datakey(&identity1).await, + Some(datakey1.clone()) + ); + assert_eq!( + key_cache + .update_sign_datakey(&identity2, datakey2.clone()) + .await?, + () + ); + assert_eq!( + key_cache.get_sign_datakey(&identity2).await, + Some(datakey2.clone()) + ); assert_eq!(key_cache.get_sign_datakey(&identity1).await, None); assert_eq!(key_cache.get_read_datakey(&identity1).await, None); - - let key_cache = TimedFixedSizeCache::new(None,None, None, None); + let 
key_cache = TimedFixedSizeCache::new(None, None, None, None); assert_eq!(key_cache.get_datakey(&identity1).await, None); - assert!(key_cache.update_datakey(&identity1, datakey1.clone()).await.is_err()); + assert!(key_cache + .update_datakey(&identity1, datakey1.clone()) + .await + .is_err()); - let key_cache = TimedFixedSizeCache::new(Some(1),None, Some(1), Some(1)); + let key_cache = TimedFixedSizeCache::new(Some(1), None, Some(1), Some(1)); assert_eq!(key_cache.update_datakey(&identity1, datakey1).await?, ()); sleep(Duration::seconds(2).to_std()?); assert_eq!(key_cache.get_datakey(&identity1).await, None); Ok(()) } - -} \ No newline at end of file +} diff --git a/src/util/config.rs b/src/util/config.rs index 12b5b71..2a4bcbd 100644 --- a/src/util/config.rs +++ b/src/util/config.rs @@ -16,11 +16,11 @@ use crate::util::error::Result; use config::{Config, File, FileFormat}; -use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher, Error}; +use notify::{Error, Event, RecommendedWatcher, RecursiveMode, Watcher}; use std::path::Path; -use tokio::sync::mpsc; use std::sync::{Arc, RwLock}; use std::time::Duration; +use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; pub struct ServerConfig { @@ -31,9 +31,12 @@ pub struct ServerConfig { impl ServerConfig { pub fn new(path: String) -> ServerConfig { let builder = Config::builder() - .set_default("tls_cert", "").expect("tls cert default to empty") - .set_default("tls_key", "").expect("tls key default to empty") - .set_default("ca_root", "").expect("ca root default to empty") + .set_default("tls_cert", "") + .expect("tls cert default to empty") + .set_default("tls_key", "") + .expect("tls key default to empty") + .set_default("ca_root", "") + .expect("ca root default to empty") .add_source(File::new(path.as_str(), FileFormat::Toml)); let config = builder.build().expect("load configuration file"); ServerConfig { diff --git a/src/util/error.rs b/src/util/error.rs index 448b2f0..86197dd 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -14,38 +14,38 @@ * */ -use std::array::TryFromSliceError; -use std::convert::Infallible; +use actix_web::cookie::KeyError; +use actix_web::{HttpResponse, ResponseError}; +use anyhow::Error as AnyhowError; +use bincode::error::{DecodeError, EncodeError}; +use chrono::{OutOfRangeError, ParseError}; use config::ConfigError; +use csrf::CsrfError; +use efi_signer::error::Error as EFIError; +use openidconnect::url::ParseError as OIDCParseError; +use openidconnect::ConfigurationError; +use openidconnect::UserInfoError; +use openssl::error::ErrorStack; use pgp::composed::key::SecretKeyParamsBuilderError; use pgp::errors::Error as PGPError; use reqwest::header::{InvalidHeaderValue, ToStrError as StrError}; use reqwest::Error as RequestError; +use rpm::Error as RPMError; +use sea_orm::DbErr; +use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use sqlx::Error as SqlxError; +use std::array::TryFromSliceError; +use std::convert::Infallible; use std::io::Error as IOError; use std::net::AddrParseError; use std::num::ParseIntError; use std::string::FromUtf8Error; use std::sync::PoisonError; -use rpm::Error as RPMError; use thiserror::Error as ThisError; use tonic::transport::Error as TonicError; -use bincode::error::{EncodeError, DecodeError}; -use chrono::{OutOfRangeError, ParseError}; -use actix_web::{ResponseError, HttpResponse}; +use utoipa::ToSchema; use validator::ValidationErrors; -use serde::{Deserialize, Serialize}; -use openssl::error::ErrorStack; -use 
actix_web::cookie::KeyError; -use openidconnect::url::ParseError as OIDCParseError; -use openidconnect::ConfigurationError; -use openidconnect::UserInfoError; -use anyhow::Error as AnyhowError; -use csrf::CsrfError; -use utoipa::{ToSchema}; -use efi_signer::error::Error as EFIError; -use sea_orm::DbErr; pub type Result = std::result::Result; @@ -138,7 +138,7 @@ pub enum Error { #[derive(Deserialize, Serialize, ToSchema)] pub struct ErrorMessage { - detail: String + detail: String, } impl ResponseError for Error { @@ -146,37 +146,37 @@ impl ResponseError for Error { match self { Error::ParameterError(_) | Error::UnsupportedTypeError(_) => { warn!("parameter error: {}", self.to_string()); - HttpResponse::BadRequest().json(ErrorMessage{ - detail: self.to_string() + HttpResponse::BadRequest().json(ErrorMessage { + detail: self.to_string(), }) } Error::NotFoundError => { warn!("record not found error: {}", self.to_string()); - HttpResponse::NotFound().json(ErrorMessage{ - detail: self.to_string() + HttpResponse::NotFound().json(ErrorMessage { + detail: self.to_string(), }) } Error::UnauthorizedError => { warn!("unauthorized: {}", self.to_string()); - HttpResponse::Unauthorized().json(ErrorMessage{ - detail: self.to_string() + HttpResponse::Unauthorized().json(ErrorMessage { + detail: self.to_string(), }) } Error::ActionsNotAllowedError(_) => { warn!("unprivileged: {}", self.to_string()); - HttpResponse::Forbidden().json(ErrorMessage{ - detail: self.to_string() + HttpResponse::Forbidden().json(ErrorMessage { + detail: self.to_string(), }) } Error::UnprivilegedError => { warn!("unprivileged: {}", self.to_string()); - HttpResponse::Forbidden().json(ErrorMessage{ - detail: self.to_string() + HttpResponse::Forbidden().json(ErrorMessage { + detail: self.to_string(), }) } _ => { warn!("internal error: {}", self.to_string()); - HttpResponse::InternalServerError().json(ErrorMessage{ + HttpResponse::InternalServerError().json(ErrorMessage { detail: self.to_string(), }) } @@ -188,18 +188,13 @@ impl From for Error { fn from(sqlx_error: SqlxError) -> Self { match sqlx_error.as_database_error() { Some(db_error) => Error::DatabaseError(db_error.to_string()), - None => { - match sqlx_error { - sqlx::Error::RowNotFound => { - Error::NotFoundError - }, - _ => { - error!("{:?}", sqlx_error); - Error::DatabaseError(format!("Unrecognized database error! {:?}", sqlx_error)) - } + None => match sqlx_error { + sqlx::Error::RowNotFound => Error::NotFoundError, + _ => { + error!("{:?}", sqlx_error); + Error::DatabaseError(format!("Unrecognized database error! 
{:?}", sqlx_error)) } - - } + }, } } } @@ -216,7 +211,6 @@ impl From for Error { } } - impl From> for Error { fn from(error: PoisonError) -> Self { Error::ConfigError(error.to_string()) @@ -302,13 +296,11 @@ impl From for Error { } impl From for Error { - fn from(err: DecodeError) -> Self - { + fn from(err: DecodeError) -> Self { Error::BincodeError(err.to_string()) } } - impl From for Error { fn from(err: OutOfRangeError) -> Self { Error::ConvertError(err.to_string()) @@ -363,14 +355,12 @@ impl From for Error { } } - impl From>> for Error { fn from(err: UserInfoError>) -> Self { - Error::AuthError(err.to_string()) -} + Error::AuthError(err.to_string()) + } } - impl From for Error { fn from(error: EFIError) -> Self { Error::EFIError(error.to_string()) @@ -384,28 +374,37 @@ impl From for Error { } impl From for Error { - fn from(error: actix_web::Error) -> Self { Error::FrameworkError(error.to_string()) } + fn from(error: actix_web::Error) -> Self { + Error::FrameworkError(error.to_string()) + } } impl From for Error { - fn from(error: data_encoding::DecodeError) -> Self { Error::FrameworkError(error.to_string()) } + fn from(error: data_encoding::DecodeError) -> Self { + Error::FrameworkError(error.to_string()) + } } impl From for Error { - fn from(error: Infallible) -> Self { Error::FrameworkError(error.to_string()) } + fn from(error: Infallible) -> Self { + Error::FrameworkError(error.to_string()) + } } impl From for Error { - fn from(error: TryFromSliceError) -> Self { Error::FrameworkError(error.to_string()) } + fn from(error: TryFromSliceError) -> Self { + Error::FrameworkError(error.to_string()) + } } impl From> for Error { - fn from(error: Vec) -> Self { Error::KeyParseError(format!("original vec {:?}", error)) } + fn from(error: Vec) -> Self { + Error::KeyParseError(format!("original vec {:?}", error)) + } } impl From for Error { - fn from(error: DbErr) -> Self { Error::DatabaseError(error.to_string()) } + fn from(error: DbErr) -> Self { + Error::DatabaseError(error.to_string()) + } } - - - diff --git a/src/util/key.rs b/src/util/key.rs index 2d4eb9d..c087d73 100644 --- a/src/util/key.rs +++ b/src/util/key.rs @@ -14,41 +14,49 @@ * */ +use crate::domain::datakey::entity::Visibility; +use crate::util::error::{Error, Result as LibraryResult}; use hex; -use rand::{thread_rng, Rng}; use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; use serde::{Serialize, Serializer}; -use std::collections::{HashMap, BTreeMap}; -use std::path::Path; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Write; -use crate::domain::datakey::entity::Visibility; -use crate::util::error::{Error, Result as LibraryResult}; +use std::path::Path; pub fn encode_u8_to_hex_string(value: &[u8]) -> String { - value - .iter() - .fold(String::new(), |mut result, n| { - let _ = write!(result, "{n:02X}"); - result - }) + value.iter().fold(String::new(), |mut result, n| { + let _ = write!(result, "{n:02X}"); + result + }) } -pub fn get_datakey_full_name(name: &str, email: &str, visibility: &Visibility) -> LibraryResult { +pub fn get_datakey_full_name( + name: &str, + email: &str, + visibility: &Visibility, +) -> LibraryResult { let names: Vec<_> = name.split(':').collect(); if visibility.to_owned() == Visibility::Public { return if names.len() <= 1 { Ok(name.to_owned()) } else { - Err(Error::ParameterError("public key name should not contains ':'".to_string())) - } + Err(Error::ParameterError( + "public key name should not contains 
':'".to_string(), + )) + }; } if names.len() <= 1 { return Ok(format!("{}:{}", email, name)); } else if names.len() > 2 { - return Err(Error::ParameterError("private key should in the format of {email}:{key_name}".to_string())) + return Err(Error::ParameterError( + "private key should in the format of {email}:{key_name}".to_string(), + )); } else if names[0] != email { - return Err(Error::ParameterError("private key email prefix not matched':'".to_string())) + return Err(Error::ParameterError( + "private key email prefix not matched':'".to_string(), + )); } Ok(name.to_owned()) } @@ -58,7 +66,11 @@ pub fn decode_hex_string_to_u8(value: &String) -> Vec { } pub fn generate_api_token() -> String { - thread_rng().sample_iter(&Alphanumeric).take(40).map(char::from).collect() + thread_rng() + .sample_iter(&Alphanumeric) + .take(40) + .map(char::from) + .collect() } pub fn generate_csrf_parent_token() -> Vec { @@ -84,7 +96,10 @@ pub fn get_token_hash(real_token: &str) -> String { hex::encode(digest) } -pub fn sorted_map(value: &HashMap, serializer: S) -> Result { +pub fn sorted_map( + value: &HashMap, + serializer: S, +) -> Result { let mut items: Vec<(_, _)> = value.iter().collect(); items.sort_by(|a, b| a.0.cmp(b.0)); BTreeMap::from_iter(items).serialize(serializer) @@ -92,10 +107,10 @@ pub fn sorted_map(value: &HashM #[cfg(test)] mod test { + use super::*; use std::env; use std::fs::File; use uuid::Uuid; - use super::*; #[test] fn test_get_datakey_full_name() { @@ -106,12 +121,24 @@ mod test { let name_with_prefix3 = "fake_email@gmail.com:fake3_email@gmail.com:test_key"; let name_without_prefix = "test_key"; //public key - assert_eq!(get_datakey_full_name(name_without_prefix, "fake_email@gmail.com", &public).unwrap(), name_without_prefix.to_string()); - get_datakey_full_name(name_with_prefix, "fake_email@gmail.com", &public).expect_err("public key name should not contains ':'"); - assert_eq!(get_datakey_full_name(name_without_prefix, "fake_email@gmail.com", &private).unwrap(), name_with_prefix.to_string()); - assert_eq!(get_datakey_full_name(name_with_prefix, "fake_email@gmail.com", &private).unwrap(), name_with_prefix.to_string()); - get_datakey_full_name(name_with_prefix2, "fake_email@gmail.com", &private).expect_err("private key email prefix not matched':'"); - get_datakey_full_name(name_with_prefix3, "fake_email@gmail.com", &private).expect_err("private key should in the format of {email}:{key_name}"); + assert_eq!( + get_datakey_full_name(name_without_prefix, "fake_email@gmail.com", &public).unwrap(), + name_without_prefix.to_string() + ); + get_datakey_full_name(name_with_prefix, "fake_email@gmail.com", &public) + .expect_err("public key name should not contains ':'"); + assert_eq!( + get_datakey_full_name(name_without_prefix, "fake_email@gmail.com", &private).unwrap(), + name_with_prefix.to_string() + ); + assert_eq!( + get_datakey_full_name(name_with_prefix, "fake_email@gmail.com", &private).unwrap(), + name_with_prefix.to_string() + ); + get_datakey_full_name(name_with_prefix2, "fake_email@gmail.com", &private) + .expect_err("private key email prefix not matched':'"); + get_datakey_full_name(name_with_prefix3, "fake_email@gmail.com", &private) + .expect_err("private key should in the format of {email}:{key_name}"); } #[test] @@ -148,9 +175,10 @@ mod test { fn test_file_exists() { //generate temp file let valid_path = env::temp_dir().join(Uuid::new_v4().to_string()); - let _valid_file = File::create(valid_path.clone()).expect("create temporary file should work"); + let _valid_file = + 
File::create(valid_path.clone()).expect("create temporary file should work"); let invalid_path = "./invalid/file/path/should/not/exists"; assert!(file_exists(valid_path.to_str().unwrap())); assert!(!file_exists(invalid_path)); } -} \ No newline at end of file +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 3dc57b4..100185a 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,6 +1,6 @@ +pub mod cache; pub mod config; pub mod error; pub mod key; -pub mod cache; pub mod options; -pub mod sign; \ No newline at end of file +pub mod sign; diff --git a/src/util/options.rs b/src/util/options.rs index 5c3bdae..743b8c3 100644 --- a/src/util/options.rs +++ b/src/util/options.rs @@ -17,4 +17,4 @@ pub const DETACHED: &str = "detached"; pub const KEY_TYPE: &str = "key_type"; pub const SIGN_TYPE: &str = "sign_type"; -pub const RPM_V3_SIGNATURE: &str = "rpm_signature_type"; \ No newline at end of file +pub const RPM_V3_SIGNATURE: &str = "rpm_signature_type"; diff --git a/src/util/sign.rs b/src/util/sign.rs index b99adf4..56b6fa4 100644 --- a/src/util/sign.rs +++ b/src/util/sign.rs @@ -64,7 +64,7 @@ impl Display for KeyType { KeyType::Pgp => write!(f, "pgp"), KeyType::X509EE => write!(f, "x509ee"), //client can use 'x509' to specify a x509 key type for the purpose of simplicity. - KeyType::X509 => write!(f, "x509ee") + KeyType::X509 => write!(f, "x509ee"), } } } -- Gitee
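
Note on the key-naming rule touched by the hunks above: get_datakey_full_name in src/util/key.rs and validate_key_token_matched in src/presentation/handler/data/sign_handler.rs both treat a bare key name as a public key, while private keys are addressed as {email}:{key_name} and the email prefix must match the requesting user. The std-only sketch below mirrors only the private-key checks from those hunks; the function name qualify_private_key_name is hypothetical and is not part of this patch or the Signatrust codebase.

/// Illustrative sketch of the {email}:{key_name} convention (hypothetical helper).
fn qualify_private_key_name(email: &str, name: &str) -> Result<String, String> {
    let parts: Vec<&str> = name.split(':').collect();
    if parts.len() <= 1 {
        // bare name: qualify it with the owner's email
        return Ok(format!("{}:{}", email, name));
    }
    if parts.len() > 2 {
        // more than one ':' is rejected
        return Err("private key should be in the format {email}:{key_name}".to_string());
    }
    if parts[0] != email {
        // prefixed name must belong to the caller
        return Err("private key email prefix does not match the caller".to_string());
    }
    Ok(name.to_string())
}

fn main() {
    assert_eq!(
        qualify_private_key_name("user@example.com", "release-key").unwrap(),
        "user@example.com:release-key"
    );
    assert!(
        qualify_private_key_name("user@example.com", "other@example.com:release-key").is_err()
    );
}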