From a9541f5cff26d6a16a00c6f859ee9b20eab92372 Mon Sep 17 00:00:00 2001 From: TommyLike Date: Tue, 25 Jul 2023 14:00:19 +0800 Subject: [PATCH] Check file existence for config option in server&client startup process --- src/client/cmd/add.rs | 47 ++++++++++++++++------- src/client/load_balancer/dns.rs | 25 ++++++++---- src/client/load_balancer/factory.rs | 25 +++++++++--- src/presentation/server/control_server.rs | 28 ++++++-------- src/presentation/server/data_server.rs | 26 +++++-------- src/util/error.rs | 38 ++++++++++-------- src/util/key.rs | 19 +++++++++ 7 files changed, 132 insertions(+), 76 deletions(-) diff --git a/src/client/cmd/add.rs b/src/client/cmd/add.rs index b0ae136..4c60f4c 100644 --- a/src/client/cmd/add.rs +++ b/src/client/cmd/add.rs @@ -37,6 +37,8 @@ use crate::client::worker::signer::RemoteSigner; use crate::client::worker::splitter::Splitter; use crate::client::worker::traits::SignHandler; use std::sync::atomic::{AtomicI32, Ordering}; +use crate::util::error::Error::CommandProcessFailed; +use crate::util::key::file_exists; lazy_static! { pub static ref FILE_EXTENSION: HashMap> = HashMap::from([ @@ -154,6 +156,10 @@ impl SignCommand for CommandAddHandler { if worker_threads == 0 { worker_threads = num_cpus::get(); } + let working_dir = config.read()?.get_string("working_dir")?; + if !file_exists(&working_dir) { + return Err(error::Error::FileFoundError(format!("working dir: {} not exists", working_dir))); + } Ok(CommandAddHandler{ worker_threads, buffer_size: config.read()?.get_string("buffer_size")?.parse()?, @@ -199,10 +205,16 @@ impl SignCommand for CommandAddHandler { let (collect_s, collect_r) = bounded::(self.max_concurrency); info!("starting to sign {} files", files.len()); let lb_config = self.config.read()?.get_table("server")?; - runtime.block_on(async { - let channel = ChannelFactory::new( - &lb_config).await.unwrap().get_channel().unwrap(); - let mut signer = RemoteSigner::new(channel, self.buffer_size); + let errored = runtime.block_on(async { + let channel_provider = ChannelFactory::new(&lb_config).await; + if let Err(err) = channel_provider { + return Some(err) + } + let channel = channel_provider.unwrap().get_channel(); + if let Err(err) = channel { + return Some(err) + } + let mut signer = RemoteSigner::new(channel.unwrap(), self.buffer_size); //split file let send_handlers = files.into_iter().map(|file|{ let task_split_s = split_s.clone(); @@ -294,20 +306,29 @@ impl SignCommand for CommandAddHandler { }); // wait for finish for h in send_handlers { - h.await.unwrap(); + if let Err(error) = h.await { + return Some(CommandProcessFailed(format!("failed to wait for send handler: {}", error.to_string()))) + } + } + for (key, channel, worker) in [ + ("split", split_s, split_handler), + ("sign", sign_s, sign_handler), + ("assemble", assemble_s, assemble_handler), + ("collect", collect_s, collect_handler), + ] { + drop(channel); + if let Err(error) = worker.await { + return Some(CommandProcessFailed(format!("failed to wait for: {0} handler to finish: {1}", key, error.to_string()))) + } } - drop(split_s); - split_handler.await.expect("split worker finished correctly"); - drop(sign_s); - sign_handler.await.expect("sign worker finished correctly"); - drop(assemble_s); - assemble_handler.await.expect("assemble worker finished correctly"); - drop(collect_s); - collect_handler.await.expect("collect worker finished correctly"); info!("Successfully signed {} files failed {} files", succeed_files.load(Ordering::Relaxed), failed_files.load(Ordering::Relaxed)); info!("sign files process finished"); + None }); + if let Some(err) = errored { + return Err(err) + } if failed_files.load(Ordering::Relaxed) != 0 { return Ok(false) } diff --git a/src/client/load_balancer/dns.rs b/src/client/load_balancer/dns.rs index 0492bc1..3fcca98 100644 --- a/src/client/load_balancer/dns.rs +++ b/src/client/load_balancer/dns.rs @@ -23,6 +23,7 @@ use async_trait::async_trait; use dns_lookup::{lookup_host}; +use crate::util::error::Error::{DNSResolveError}; pub struct DNSLoadBalancer { hostname: String, @@ -46,15 +47,23 @@ impl DNSLoadBalancer { impl DynamicLoadBalancer for DNSLoadBalancer { fn get_transport_channel(&self) -> Result { let mut endpoints = Vec::new(); - for ip in lookup_host(&self.hostname)?.into_iter() { - let mut endpoint = Endpoint::from_shared( - format!("http://{}:{}", ip, self.port))?; - if let Some(tls_config) = self.client_config.clone() { - endpoint = endpoint.tls_config(tls_config)?; + match lookup_host(&self.hostname) { + Ok(hosts) => { + for ip in hosts.into_iter() { + let mut endpoint = Endpoint::from_shared( + format!("http://{}:{}", ip, self.port))?; + if let Some(tls_config) = self.client_config.clone() { + endpoint = endpoint.tls_config(tls_config)?; + } + info!("found endpoint {}:{} for signing task.", ip, self.port); + endpoints.push(endpoint); + } + Ok(Channel::balance_list(endpoints.into_iter())) + } + Err(_) => { + Err(DNSResolveError(self.hostname.clone())) } - info!("found endpoint {}:{} for signing task.", ip, self.port); - endpoints.push(endpoint); } - Ok(Channel::balance_list(endpoints.into_iter())) + } } \ No newline at end of file diff --git a/src/client/load_balancer/factory.rs b/src/client/load_balancer/factory.rs index c02f93b..79ed045 100644 --- a/src/client/load_balancer/factory.rs +++ b/src/client/load_balancer/factory.rs @@ -21,6 +21,8 @@ use crate::client::load_balancer::dns::DNSLoadBalancer; use crate::client::load_balancer::single::SingleLoadBalancer; use crate::client::load_balancer::traits::DynamicLoadBalancer; use crate::util::error::{Error, Result}; +use crate::util::error::Error::ConfigError; +use crate::util::key::file_exists; pub struct ChannelFactory { lb: Box @@ -29,15 +31,26 @@ pub struct ChannelFactory { impl ChannelFactory { pub async fn new(config: &HashMap) -> Result { let mut client_config :Option = None; - let tls_cert = config.get("tls_cert").unwrap_or(&Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); - let tls_key = config.get("tls_key").unwrap_or(&Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); - let server_port = config.get("server_port").expect("server port not in client config").to_string(); + let tls_cert = config.get("tls_cert").unwrap_or( + &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); + let tls_key = config.get("tls_key").unwrap_or( + &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); + let server_address = config.get("server_address").unwrap_or( + &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); + let server_port = config.get("server_port").unwrap_or( + &Value::new(Some(&String::new()), config::ValueKind::String(String::new()))).to_string(); + if server_address.is_empty() || server_port.is_empty() { + return Err(ConfigError(format!("server address: {} or port: {} not configured", server_address, server_port))); + } if tls_cert.is_empty() || tls_key.is_empty() { info!("tls client key and cert not configured, tls will be disabled"); } else { info!("tls client key and cert configured, tls will be enabled"); debug!("tls cert:{}, tls key:{}", tls_cert, tls_key); + if !file_exists(&tls_cert) || !file_exists(&tls_key){ + return Err(Error::FileFoundError(format!("client tls cert {} or key {} file not found", tls_key, tls_cert))); + } let identity = Identity::from_pem( tokio::fs::read(tls_cert).await?, tokio::fs::read(tls_key).await?); @@ -48,17 +61,17 @@ impl ChannelFactory { if lb_type == "single" { return Ok(Self { lb: Box::new(SingleLoadBalancer::new( - config.get("server_address").unwrap_or(&Value::default()).to_string(), + server_address, server_port, client_config)?) }) } else if lb_type == "dns" { return Ok(Self { lb: Box::new(DNSLoadBalancer::new( - config.get("server_address").unwrap_or(&Value::default()).to_string(), + server_address, server_port, client_config)?) }) } - Err(Error::ConfigError(format!("invalid load balancer type configuration {}", lb_type))) + Err(ConfigError(format!("invalid load balancer type configuration: {}", lb_type))) } pub fn get_channel(&self) -> Result { diff --git a/src/presentation/server/control_server.rs b/src/presentation/server/control_server.rs index 1d3e5b8..58a5972 100644 --- a/src/presentation/server/control_server.rs +++ b/src/presentation/server/control_server.rs @@ -22,7 +22,7 @@ use utoipa::{ use utoipa_swagger_ui::SwaggerUi; use std::sync::{Arc, RwLock}; use actix_web::{App, HttpServer, middleware, web, cookie::Key}; -use config::Config; +use config::{Config}; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; use actix_identity::{IdentityMiddleware}; use actix_session::{config::PersistentSession, storage::RedisSessionStore, SessionMiddleware}; @@ -37,7 +37,7 @@ use actix_web::{dev::ServiceRequest}; use actix_web::cookie::SameSite; use secstr::SecVec; use tokio_util::sync::CancellationToken; -use crate::util::error::Result; +use crate::util::error::{Error, Result}; use crate::application::datakey::{DBKeyService, KeyService}; use crate::infra::database::model::token::repository::TokenRepository; @@ -49,7 +49,7 @@ use crate::domain::token::entity::Token; use crate::domain::user::entity::User; use crate::presentation::handler::control::model::token::dto::{CreateTokenDTO}; use crate::presentation::handler::control::model::user::dto::{UserIdentity}; -use crate::util::key::truncate_string_to_protect_key; +use crate::util::key::{file_exists, truncate_string_to_protect_key}; pub struct ControlServer { server_config: Arc>, @@ -232,24 +232,18 @@ impl ControlServer { .service(web::scope("/api") .service(health_handler::get_scope())) }); - if self.server_config - .read()? - .get_string("tls_cert")? - .is_empty() - || self - .server_config - .read()? - .get_string("tls_key")? - .is_empty() { + let tls_cert = self.server_config.read()?.get_string("tls_cert").unwrap_or(String::new()).to_string(); + let tls_key = self.server_config.read()?.get_string("tls_key").unwrap_or(String::new()).to_string(); + if tls_cert.is_empty() || tls_key.is_empty() { info!("tls key and cert not configured, control server tls will be disabled"); http_server.bind(addr)?.run().await?; } else { + if !file_exists(&tls_cert) || !file_exists(&tls_key) { + return Err(Error::FileFoundError(format!("tls cert: {} or key: {} file not found", tls_key, tls_cert))); + } let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file( - self.server_config.read()?.get_string("tls_key")?, SslFiletype::PEM).unwrap(); - builder.set_certificate_chain_file( - self.server_config.read()?.get_string("tls_cert")?).unwrap(); + builder.set_private_key_file(tls_key, SslFiletype::PEM).unwrap(); + builder.set_certificate_chain_file(tls_cert).unwrap(); http_server.bind_openssl(addr, builder)?.run().await?; } Ok(()) diff --git a/src/presentation/server/data_server.rs b/src/presentation/server/data_server.rs index e99baa7..2727ebe 100644 --- a/src/presentation/server/data_server.rs +++ b/src/presentation/server/data_server.rs @@ -16,7 +16,7 @@ use std::net::SocketAddr; use std::sync::{Arc, RwLock}; -use config::Config; +use config::{Config}; use tokio::fs; use tokio_util::sync::CancellationToken; use tonic::{ @@ -37,7 +37,8 @@ use crate::infra::sign_backend::factory::SignBackendFactory; use crate::presentation::handler::data::sign_handler::get_grpc_handler as sign_grpc_handler; use crate::presentation::handler::data::health_handler::get_grpc_handler as health_grpc_handler; -use crate::util::error::Result; +use crate::util::error::{Error, Result}; +use crate::util::key::file_exists; pub struct DataServer { @@ -62,23 +63,16 @@ impl DataServer { } async fn load(&mut self) -> Result<()> { - if self - .server_config - .read()? - .get_string("tls_cert")? - .is_empty() - || self - .server_config - .read()? - .get_string("tls_key")? - .is_empty() - { + let ca_root = self.server_config.read()?.get("ca_root").unwrap_or(String::new()).to_string(); + let tls_cert = self.server_config.read()?.get("tls_cert").unwrap_or(String::new()).to_string(); + let tls_key = self.server_config.read()?.get("tls_key").unwrap_or(String::new()).to_string(); + if ca_root.is_empty() || tls_cert.is_empty() || tls_key.is_empty() { info!("tls key and cert not configured, data server tls will be disabled"); return Ok(()); } - let ca_root = self.server_config.read()?.get_string("ca_root").expect("ca_root not configured"); - let tls_cert = self.server_config.read()?.get_string("tls_cert")?; - let tls_key = self.server_config.read()?.get_string("tls_key")?; + if !file_exists(&ca_root) || !file_exists(&tls_cert) || !file_exists(&tls_key) { + return Err(Error::FileFoundError(format!("ca root: {} or tls cert: {} or key: {} file not found", ca_root, tls_key, tls_cert))); + } self.ca_cert = Some( Certificate::from_pem(fs::read(ca_root).await?)); self.server_identity = Some(Identity::from_pem(fs::read(tls_cert).await?, diff --git a/src/util/error.rs b/src/util/error.rs index c05b35e..e680140 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -51,38 +51,40 @@ pub type Result = std::result::Result; #[allow(clippy::enum_variant_names)] #[derive(Debug, ThisError, Clone)] pub enum Error { - #[error("An error occurred in database operation. {0}")] + #[error("An error occurred in database operation: {0}")] DatabaseError(String), - #[error("An error occurred when loading configure file. {0}")] + #[error("An error occurred when loading configure: {0}")] ConfigError(String), - #[error("An error occurred when perform IO requests. {0}")] + #[error("An error occurred when perform IO requests: {0}")] IOError(String), - #[error("unsupported type configured. {0}")] + #[error("unsupported type configured: {0}")] UnsupportedTypeError(String), - #[error("kms invoke error. {0}")] + #[error("kms invoke error: {0}")] KMSInvokeError(String), - #[error("failed to serialize/deserialize. {0}")] + #[error("failed to serialize/deserialize: {0}")] SerializeError(String), - #[error("failed to perform http request. {0}")] + #[error("failed to perform http request: {0}")] HttpRequest(String), - #[error("failed to convert. {0}")] + #[error("failed to convert: {0}")] ConvertError(String), - #[error("failed to encode/decode. {0}")] + #[error("failed to encode/decode: {0}")] EncodeError(String), - #[error("failed to get cluster key. {0}")] + #[error("failed to get cluster key: {0}")] ClusterError(String), - #[error("failed to serialize/deserialize key. {0}")] + #[error("failed to serialize/deserialize key: {0}")] KeyParseError(String), - #[error("failed to sign with key {0}. {1}")] + #[error("failed to sign with key {0}: {1}")] SignError(String, String), - #[error("failed to perform pgp {0}")] + #[error("failed to perform pgp: {0}")] PGPInvokeError(String), - #[error("failed to perform openssl {0}")] + #[error("failed to perform openssl: {0}")] X509InvokeError(String), - #[error("invalid parameter error {0}")] + #[error("invalid parameter error: {0}")] ParameterError(String), #[error("record not found error")] NotFoundError, + #[error("fail to load file: {0}")] + FileFoundError(String), #[error("invalid user")] UnauthorizedError, #[error("invalid cookie key found")] @@ -103,7 +105,7 @@ pub enum Error { FrameworkError(String), //client error - #[error("file extension {0} not supported for file {1}")] + #[error("file extension: {0} not supported for file: {1}")] FileNotSupportError(String, String), #[error("not any valid file found")] NoFileCandidateError, @@ -127,6 +129,10 @@ pub enum Error { EFIError(String), #[error("file content is empty")] FileContentEmpty, + #[error("Failed to get IP addresses by hostname: {0}")] + DNSResolveError(String), + #[error("Failed to execute process: {0}")] + CommandProcessFailed(String), } #[derive(Deserialize, Serialize, ToSchema)] diff --git a/src/util/key.rs b/src/util/key.rs index 4ccb9c8..44a2833 100644 --- a/src/util/key.rs +++ b/src/util/key.rs @@ -19,6 +19,7 @@ use rand::{thread_rng, Rng}; use rand::distributions::Alphanumeric; use serde::{Serialize, Serializer}; use std::collections::{HashMap, BTreeMap}; +use std::path::Path; use sha2::{Sha256, Digest}; pub fn encode_u8_to_hex_string(value: &[u8]) -> String { @@ -47,6 +48,11 @@ pub fn truncate_string_to_protect_key(s: &str) -> [u8; 32] { result } +pub fn file_exists(file_path: &str) -> bool { + let path = Path::new(file_path); + path.exists() +} + pub fn get_token_hash(real_token: &str) -> String { let mut hasher = Sha256::default(); hasher.update(real_token); @@ -62,6 +68,9 @@ pub fn sorted_map(value: &HashM #[cfg(test)] mod test { + use std::env; + use std::fs::File; + use uuid::Uuid; use super::*; #[test] @@ -93,4 +102,14 @@ mod test { let content_a = encode_u8_to_hex_string(&decoded); assert_eq!(content, content_a); } + + #[test] + fn test_file_exists() { + //generate temp file + let valid_path = env::temp_dir().join(Uuid::new_v4().to_string()); + let _valid_file = File::create(valid_path.clone()).expect("create temporary file should work"); + let invalid_path = "./invalid/file/path/should/not/exists"; + assert!(file_exists(valid_path.to_str().unwrap())); + assert!(!file_exists(invalid_path)); + } } \ No newline at end of file -- Gitee