From 5c9b867d680817ed37b2238a02e907c15fdf4949 Mon Sep 17 00:00:00 2001 From: tangmengcheng <745274877@qq.com> Date: Thu, 19 Jun 2025 11:27:48 +0800 Subject: [PATCH 1/2] dynolog open source code --- msmonitor/dynolog_npu/cli/Cargo.toml | 23 +- .../dynolog_npu/cli/src/commands/dcgm.rs | 47 +- .../dynolog_npu/cli/src/commands/gputrace.rs | 17 +- msmonitor/dynolog_npu/cli/src/commands/mod.rs | 9 +- .../dynolog_npu/cli/src/commands/status.rs | 23 +- .../dynolog_npu/cli/src/commands/utils.rs | 65 +- .../dynolog_npu/cli/src/commands/version.rs | 20 +- msmonitor/dynolog_npu/cli/src/main.rs | 655 +------------ msmonitor/dynolog_npu/dynolog/src/Main.cpp | 154 ++- msmonitor/dynolog_npu/dynolog/src/Metrics.cpp | 104 +- .../dynolog/src/rpc/SimpleJsonServer.cpp | 899 ++++-------------- .../dynolog/src/rpc/SimpleJsonServer.h | 103 +- .../dynolog/src/tracing/IPCMonitor.cpp | 233 ++--- .../dynolog/src/tracing/IPCMonitor.h | 35 +- 14 files changed, 583 insertions(+), 1804 deletions(-) diff --git a/msmonitor/dynolog_npu/cli/Cargo.toml b/msmonitor/dynolog_npu/cli/Cargo.toml index 5b5917add32..479ec29ce25 100644 --- a/msmonitor/dynolog_npu/cli/Cargo.toml +++ b/msmonitor/dynolog_npu/cli/Cargo.toml @@ -7,27 +7,8 @@ edition = "2021" anyhow = "1.0.57" clap = { version = "3.1.0", features = ["derive"]} serde_json = "1.0" -rustls = "0.21.0" -rustls-pemfile = "1.0" -webpki = "0.22" -x509-parser = "0.15" -der-parser = "8" -pem = "1.1" -chrono = "0.4" -num-bigint = "0.4" -openssl = { version = "0.10", features = ["vendored"] } -rpassword = "7.2.0" +# Make it work with conda +# See https://github.com/rust-lang/cargo/issues/6652 [net] git-fetch-with-cli = true - -[build] -rustflags = [ - "-C", "relocation_model=pie", - "-C", "link-args=-Wl,-z,now", - "-C", "link-args=-Wl,-z,relro", - "-C", "strip=symbols", - "-C", "overflow_checks", - "-C", "link-args=-static-libgcc", - "-C", "link-args=-static-libstdc++" -] diff --git a/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs index 9f342c04dfd..89fd9ec95c6 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs @@ -3,30 +3,43 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use std::net::TcpStream; + use anyhow::Result; -use crate::DynoClient; -use super::utils; + +#[path = "utils.rs"] +mod utils; // This module contains the handling logic for dcgm /// Pause dcgm module profiling -pub fn run_dcgm_pause( - mut client: DynoClient, - duration_s: i32, -) -> Result<()> { - let msg = format!(r#"{{"fn":"dcgmPause", "duration_s":{}}}"#, duration_s); - utils::send_msg(&mut client, &msg)?; - let resp_str = utils::get_resp(&mut client)?; - println!("{}", resp_str); +pub fn run_dcgm_pause(client: TcpStream, duration_s: i32) -> Result<()> { + let request_json = format!( + r#" +{{ + "fn": "dcgmProfPause", + "duration_s": {} +}}"#, + duration_s + ); + + utils::send_msg(&client, &request_json).expect("Error sending message to service"); + + let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + + println!("response = {}", resp_str); + Ok(()) } /// Resume dcgm module profiling -pub fn run_dcgm_resume( - mut client: DynoClient, -) -> Result<()> { - utils::send_msg(&mut client, r#"{"fn":"dcgmResume"}"#)?; - let resp_str = utils::get_resp(&mut client)?; - println!("{}", resp_str); +pub fn run_dcgm_resume(client: TcpStream) -> Result<()> { + utils::send_msg(&client, r#"{"fn":"dcgmProfResume"}"#) + .expect("Error sending message to service"); + + let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + + println!("response = {}", resp_str); + Ok(()) -} \ No newline at end of file +} diff --git a/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs index 677ccf27d10..c2e37af7e31 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs @@ -3,10 +3,13 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use std::net::TcpStream; + use anyhow::Result; use serde_json::Value; -use crate::DynoClient; -use super::utils; + +#[path = "utils.rs"] +mod utils; // This module contains the handling logic for dyno gputrace @@ -92,7 +95,7 @@ impl GpuTraceConfig { /// Gputrace command triggers GPU profiling on pytorch apps pub fn run_gputrace( - mut client: DynoClient, + client: TcpStream, job_id: u64, pids: &str, process_limit: u32, @@ -114,11 +117,11 @@ pub fn run_gputrace( kineto_config, job_id, pids, process_limit ); - utils::send_msg(&mut client, &request_json)?; + utils::send_msg(&client, &request_json).expect("Error sending message to service"); - let resp_str = utils::get_resp(&mut client)?; + let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - println!("response = {}\n", resp_str); + println!("response = {}", resp_str); let resp_v: Value = serde_json::from_str(&resp_str)?; let processes = resp_v["processesMatched"].as_array().unwrap(); @@ -210,4 +213,4 @@ PROFILE_WITH_FLOPS=false PROFILE_WITH_MODULES=true"# ); } -} \ No newline at end of file +} diff --git a/msmonitor/dynolog_npu/cli/src/commands/mod.rs b/msmonitor/dynolog_npu/cli/src/commands/mod.rs index 1f3bf17ad79..5acfacb640e 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/mod.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/mod.rs @@ -9,11 +9,8 @@ // handling code. Additionally, explicitly "exporting" all the command modules here allows // us to avoid having to explicitly list all the command modules in main.rs. -pub mod status; -pub mod version; pub mod dcgm; pub mod gputrace; -pub mod nputrace; -pub mod npumonitor; -pub mod utils; -// ... add new command modules here \ No newline at end of file +pub mod status; +pub mod version; +// ... add new command modules here diff --git a/msmonitor/dynolog_npu/cli/src/commands/status.rs b/msmonitor/dynolog_npu/cli/src/commands/status.rs index 1be17956c12..c5d9d75d909 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/status.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/status.rs @@ -3,13 +3,22 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use std::net::TcpStream; + use anyhow::Result; -use crate::DynoClient; -use super::utils; -pub fn run_status(mut client: DynoClient) -> Result<()> { - utils::send_msg(&mut client, r#"{"fn":"getStatus"}"#)?; - let resp_str = utils::get_resp(&mut client)?; - println!("{}", resp_str); +#[path = "utils.rs"] +mod utils; + +// This module contains the handling logic for dyno status + +/// Get system info +pub fn run_status(client: TcpStream) -> Result<()> { + utils::send_msg(&client, r#"{"fn":"getStatus"}"#).expect("Error sending message to service"); + + let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + + println!("response = {}", resp_str); + Ok(()) -} \ No newline at end of file +} diff --git a/msmonitor/dynolog_npu/cli/src/commands/utils.rs b/msmonitor/dynolog_npu/cli/src/commands/utils.rs index c2fdd3de618..b156c1c812d 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/utils.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/utils.rs @@ -3,48 +3,33 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::io::{Read, Write}; +use std::io::Read; +use std::io::Write; +use std::net::TcpStream; use anyhow::Result; -use crate::DynoClient; - -pub fn send_msg(client: &mut DynoClient, msg: &str) -> Result<()> { - match client { - DynoClient::Secure(secure_client) => { - let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); - secure_client.write_all(&msg_len)?; - secure_client.write_all(msg.as_bytes())?; - secure_client.flush()?; - } - DynoClient::Insecure(insecure_client) => { - let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); - insecure_client.write_all(&msg_len)?; - insecure_client.write_all(msg.as_bytes())?; - insecure_client.flush()?; - } - } - Ok(()) +pub fn send_msg(mut client: &TcpStream, msg: &str) -> Result<()> { + let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); + + client.write_all(&msg_len)?; + client.write_all(msg.as_bytes()).map_err(|err| err.into()) } -pub fn get_resp(client: &mut DynoClient) -> Result { - let mut len_buf = [0u8; 4]; - let mut resp_buf; - - match client { - DynoClient::Secure(secure_client) => { - secure_client.read_exact(&mut len_buf)?; - let len = u32::from_ne_bytes(len_buf) as usize; - resp_buf = vec![0u8; len]; - secure_client.read_exact(&mut resp_buf)?; - } - DynoClient::Insecure(insecure_client) => { - insecure_client.read_exact(&mut len_buf)?; - let len = u32::from_ne_bytes(len_buf) as usize; - resp_buf = vec![0u8; len]; - insecure_client.read_exact(&mut resp_buf)?; - } - } - - Ok(String::from_utf8(resp_buf)?) -} \ No newline at end of file +pub fn get_resp(mut client: &TcpStream) -> Result { + // Response is prefixed with length + let mut resp_len: [u8; 4] = [0; 4]; + client.read_exact(&mut resp_len)?; + + let resp_len = i32::from_ne_bytes(resp_len); + let resp_len = usize::try_from(resp_len).unwrap(); + + println!("response length = {}", resp_len); + + let mut resp_str = Vec::::new(); + resp_str.resize(resp_len, 0); + + client.read_exact(resp_str.as_mut_slice())?; + + String::from_utf8(resp_str).map_err(|err| err.into()) +} diff --git a/msmonitor/dynolog_npu/cli/src/commands/version.rs b/msmonitor/dynolog_npu/cli/src/commands/version.rs index 31139d56852..94b5129f47c 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/version.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/version.rs @@ -3,16 +3,22 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use std::net::TcpStream; + use anyhow::Result; -use crate::DynoClient; -use super::utils; + +#[path = "utils.rs"] +mod utils; // This module contains the handling logic for querying dyno version /// Get version info -pub fn run_version(mut client: DynoClient) -> Result<()> { - utils::send_msg(&mut client, r#"{"fn":"getVersion"}"#)?; - let resp_str = utils::get_resp(&mut client)?; - println!("{}", resp_str); +pub fn run_version(client: TcpStream) -> Result<()> { + utils::send_msg(&client, r#"{"fn":"getVersion"}"#).expect("Error sending message to service"); + + let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + + println!("response = {}", resp_str); + Ok(()) -} \ No newline at end of file +} diff --git a/msmonitor/dynolog_npu/cli/src/main.rs b/msmonitor/dynolog_npu/cli/src/main.rs index 2bd85a79637..b7c755e61ba 100644 --- a/msmonitor/dynolog_npu/cli/src/main.rs +++ b/msmonitor/dynolog_npu/cli/src/main.rs @@ -2,38 +2,18 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::fs::File; -use std::io::BufReader; -use rustls::{Certificate, RootCertStore, PrivateKey, ClientConnection, StreamOwned}; -use std::sync::Arc; + use std::net::TcpStream; use std::net::ToSocketAddrs; -use std::path::PathBuf; -use std::io; -use rpassword::prompt_password; use anyhow::Result; use clap::Parser; -use std::collections::HashSet; - -use x509_parser::prelude::*; -use x509_parser::num_bigint::ToBigInt; -use std::fs::read_to_string; -use x509_parser::public_key::RSAPublicKey; -use x509_parser::der_parser::oid; -use num_bigint::BigUint; -use openssl::pkey::PKey; -use std::io::Read; // Make all the command modules accessible to this file. mod commands; use commands::gputrace::GpuTraceConfig; use commands::gputrace::GpuTraceOptions; use commands::gputrace::GpuTraceTriggerConfig; -use commands::nputrace::NpuTraceConfig; -use commands::nputrace::NpuTraceOptions; -use commands::nputrace::NpuTraceTriggerConfig; -use commands::npumonitor::NpuMonitorConfig; use commands::*; /// Instructions on adding a new Dyno CLI command: @@ -52,7 +32,6 @@ use commands::*; /// the command dispatching logic clear and concise, please keep the code in the match branch to a minimum. const DYNO_PORT: u16 = 1778; -const MIN_RSA_KEY_LENGTH: u64 = 3072; // 最小 RSA 密钥长度(位) #[derive(Debug, Parser)] struct Opts { @@ -60,49 +39,10 @@ struct Opts { hostname: String, #[clap(long, default_value_t = DYNO_PORT)] port: u16, - #[clap(long, required = true)] - certs_dir: String, #[clap(subcommand)] cmd: Command, } -const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy"]; - -fn parse_mspti_activity_kinds(src: &str) -> Result{ - let allowed_values: HashSet<&str> = ALLOWED_VALUES.iter().cloned().collect(); - - let kinds: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); - - for kind in &kinds { - if !allowed_values.contains(kind) { - return Err(format!("Invalid MSPTI activity kind: {}, Possible values: {:?}.]", kind, allowed_values)); - } - } - - Ok(src.to_string()) -} - -const ALLOWED_HOST_SYSTEM_VALUES: &[&str] = &["cpu", "mem", "disk", "network", "osrt"]; - -fn parse_host_sys(src: &str) -> Result{ - if src == "None" { - return Ok(src.to_string()); - } - - let allowed_host_sys_values: HashSet<&str> = ALLOWED_HOST_SYSTEM_VALUES.iter().cloned().collect(); - - let host_systems: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); - - for host_system in &host_systems { - if !allowed_host_sys_values.contains(host_system) { - return Err(format!("Invalid NPU Trace host system: {}, Possible values: {:?}.]", host_system, - allowed_host_sys_values)); - } - } - let result = host_systems.join(","); - Ok(result) -} - #[derive(Debug, Parser)] enum Command { /// Check the status of a dynolog process @@ -111,7 +51,7 @@ enum Command { Version, /// Capture gputrace Gputrace { - /// Job id of the application to trace. + /// Job id of the application to trace #[clap(long, default_value_t = 0)] job_id: u64, /// List of pids to capture trace for (comma separated). @@ -126,134 +66,32 @@ enum Command { /// Log file for trace. #[clap(long)] log_file: String, - /// Unix timestamp used for synchronized collection (milliseconds since epoch). + /// Unix timestamp used for synchronized collection (milliseconds since epoch) #[clap(long, default_value_t = 0)] profile_start_time: u64, /// Start iteration roundup, starts an iteration based trace at a multiple /// of this value. #[clap(long, default_value_t = 1)] profile_start_iteration_roundup: u64, - /// Max number of processes to profile. + /// Max number of processes to profile #[clap(long, default_value_t = 3)] process_limit: u32, - /// Record PyTorch operator input shapes and types. + /// Record PyTorch operator input shapes and types #[clap(long, action)] record_shapes: bool, - /// Profile PyTorch memory. + /// Profile PyTorch memory #[clap(long, action)] profile_memory: bool, - /// Capture Python stacks in traces. + /// Capture Python stacks in traces #[clap(long, action)] with_stacks: bool, - /// Annotate operators with analytical flops. + /// Annotate operators with analytical flops #[clap(long, action)] with_flops: bool, - /// Capture PyTorch operator modules in traces. + /// Capture PyTorch operator modules in traces #[clap(long, action)] with_modules: bool, }, - /// Capture nputrace. Subcommand functions aligned with Ascend Torch Profiler. - Nputrace { - /// Job id of the application to trace. - #[clap(long, default_value_t = 0)] - job_id: u64, - /// List of pids to capture trace for (comma separated). - #[clap(long, default_value = "0")] - pids: String, - /// Duration of trace to collect in ms. - #[clap(long, default_value_t = 500)] - duration_ms: u64, - /// Training iterations to collect, this takes precedence over duration. - #[clap(long, default_value_t = -1)] - iterations: i64, - /// Log file for trace. - #[clap(long)] - log_file: String, - /// Unix timestamp used for synchronized collection (milliseconds since epoch). - #[clap(long, default_value_t = 0)] - profile_start_time: u64, - /// Number of steps to start profile. - #[clap(long, default_value_t = 0)] - start_step: u64, - /// Max number of processes to profile. - #[clap(long, default_value_t = 3)] - process_limit: u32, - /// Whether to record PyTorch operator input shapes and types. - #[clap(long, action)] - record_shapes: bool, - /// Whether to profile PyTorch memory. - #[clap(long, action)] - profile_memory: bool, - /// Whether to profile the Python call stack in trace. - #[clap(long, action)] - with_stack: bool, - /// Annotate operators with analytical flops. - #[clap(long, action)] - with_flops: bool, - /// Whether to profile PyTorch operator modules in traces. - #[clap(long, action)] - with_modules: bool, - /// The scope of the profile's events. - #[clap(long, value_parser = ["CPU,NPU", "NPU,CPU", "CPU", "NPU"], default_value = "CPU,NPU")] - activities: String, - /// Profiler level. - #[clap(long, value_parser = ["Level0", "Level1", "Level2", "Level_none"], default_value = "Level0")] - profiler_level: String, - /// AIC metrics. - #[clap(long, value_parser = ["AiCoreNone", "PipeUtilization", "ArithmeticUtilization", "Memory", "MemoryL0", "ResourceConflictRatio", "MemoryUB", "L2Cache", "MemoryAccess"], default_value = "AiCoreNone")] - aic_metrics: String, - /// Whether to analyse the data after collection. - #[clap(long, action)] - analyse: bool, - /// Whether to collect L2 cache. - #[clap(long, action)] - l2_cache: bool, - /// Whether to collect op attributes. - #[clap(long, action)] - op_attr: bool, - /// Whether to enable MSTX. - #[clap(long, action)] - msprof_tx: bool, - /// GC detect threshold. - #[clap(long)] - gc_detect_threshold: Option, - /// Whether to streamline data after analyse is complete. - #[clap(long, value_parser = ["true", "false"], default_value = "true")] - data_simplification: String, - /// Types of data exported by the profiler. - #[clap(long, value_parser = ["Text", "Db"], default_value = "Text")] - export_type: String, - /// Obtain the system data on the host side. - #[clap(long, value_parser = parse_host_sys, default_value = "None")] - host_sys: String, - /// Whether to enable sys io. - #[clap(long, action)] - sys_io: bool, - /// Whether to enable sys interconnection. - #[clap(long, action)] - sys_interconnection: bool, - /// The domain that needs to be enabled in mstx mode. - #[clap(long)] - mstx_domain_include: Option, - /// Domains that do not need to be enabled in mstx mode. - #[clap(long)] - mstx_domain_exclude: Option, - }, - /// Ascend MSPTI Monitor - NpuMonitor { - /// Start NPU monitor. - #[clap(long, action)] - npu_monitor_start: bool, - /// Stop NPU monitor. - #[clap(long, action)] - npu_monitor_stop: bool, - /// NPU monitor report interval in seconds. - #[clap(long, default_value_t = 60)] - report_interval_s: u32, - /// MSPTI collect activity kind - #[clap(long, value_parser = parse_mspti_activity_kinds, default_value = "Marker")] - mspti_activity_kind: String, - }, /// Pause dcgm profiling. This enables running tools like Nsight compute and avoids conflicts. DcgmPause { /// Duration to pause dcgm profiling in seconds @@ -264,397 +102,29 @@ enum Command { DcgmResume, } -struct ClientConfigPath { - cert_path: PathBuf, - key_path: PathBuf, - ca_cert_path: PathBuf, -} - -fn verify_certificate(cert_der: &[u8], is_root_cert: bool) -> Result<()> { - // 解析 X509 证书 - let (_, cert) = X509Certificate::from_der(cert_der) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; - - // 检查证书版本是否为 X.509v3 - if cert.tbs_certificate.version != X509Version(2) { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Certificate is not X.509v3" - ).into()); - } - - // 检查证书签名算法 - let sig_alg = cert.signature_algorithm.algorithm; - - // 定义不安全的算法 OID - let md2_rsa = oid!(1.2.840.113549.1.1.2); // MD2 with RSA - let md5_rsa = oid!(1.2.840.113549.1.1.4); // MD5 with RSA - let sha1_rsa = oid!(1.2.840.113549.1.1.5); // SHA1 with RSA - - // 检查是否使用不安全的算法 - if sig_alg == md2_rsa || sig_alg == md5_rsa || sig_alg == sha1_rsa { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Certificate uses insecure signature algorithm" - ).into()); - } - - // 定义 RSA 签名算法 OID - let rsa_sha256 = oid!(1.2.840.113549.1.1.11); // RSA with SHA256 - let rsa_sha384 = oid!(1.2.840.113549.1.1.12); // RSA with SHA384 - let rsa_sha512 = oid!(1.2.840.113549.1.1.13); // RSA with SHA512 - - // 检查 RSA 密钥长度 - if sig_alg == rsa_sha256 || sig_alg == rsa_sha384 || sig_alg == rsa_sha512 { - // 获取公钥 - if let Ok((_, public_key)) = SubjectPublicKeyInfo::from_der(&cert.tbs_certificate.subject_pki.subject_public_key.data) { - if let Ok((_, rsa_key)) = RSAPublicKey::from_der(&public_key.subject_public_key.data) { - // 检查 RSA 密钥长度 - let modulus = BigUint::from_bytes_be(&rsa_key.modulus); - let key_length = modulus.bits(); - if key_length < MIN_RSA_KEY_LENGTH { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("RSA key length {} bits is less than required {} bits", key_length, MIN_RSA_KEY_LENGTH) - ).into()); - } - } - } - } - - // 检查证书的扩展域 - let mut has_ca_constraint = false; - let mut has_key_usage = false; - let mut has_crl_sign = false; - let mut has_cert_sign = false; - - for ext in cert.tbs_certificate.extensions() { - if ext.oid == oid_registry::OID_X509_EXT_BASIC_CONSTRAINTS { - if let Ok((_, constraints)) = BasicConstraints::from_der(ext.value) { - has_ca_constraint = constraints.ca; - } else { - println!("Failed to parse Basic Constraints"); - } - } else if ext.oid == oid_registry::OID_X509_EXT_KEY_USAGE { - println!("Found Key Usage extension"); - if let Ok((_, usage)) = KeyUsage::from_der(ext.value) { - has_key_usage = true; - has_cert_sign = usage.key_cert_sign(); - has_crl_sign = usage.crl_sign(); - } else { - println!("Failed to parse Key Usage"); - } - } - } - - // 根据证书类型进行不同的验证 - if is_root_cert { - // 根证书验证要求 - if !has_ca_constraint { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Root certificate must have CA constraint" - ).into()); - } - if !has_key_usage { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Root certificate must have key usage extension" - ).into()); - } - if !has_cert_sign { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Root certificate must have certificate signature permission" - ).into()); - } - if !has_crl_sign { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Root certificate must have CRL signature permission" - ).into()); - } - } else { - // 客户端证书验证要求 - if has_ca_constraint { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Client certificate should not have CA constraint" - ).into()); - } - if !has_key_usage { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Client certificate must have key usage extension" - ).into()); - } - } - - // 检查证书有效期 - let now = chrono::Utc::now(); - let not_before = chrono::DateTime::from_timestamp( - cert.tbs_certificate.validity.not_before.timestamp(), - 0 - ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_before date"))?; - - let not_after = chrono::DateTime::from_timestamp( - cert.tbs_certificate.validity.not_after.timestamp(), - 0 - ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_after date"))?; - - if now < not_before { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Certificate is not yet valid. Valid from: {}", not_before) - ).into()); - } - - if now > not_after { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Certificate has expired. Expired at: {}", not_after) - ).into()); - } - - Ok(()) -} - -fn is_cert_revoked(cert_der: &[u8], crl_path: &PathBuf) -> Result { - // 解析 X509 证书 - let (_, cert) = X509Certificate::from_der(cert_der) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; - - // 读取 CRL 文件 - let crl_data = read_to_string(crl_path)?; - let (_, pem) = pem::parse_x509_pem(crl_data.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL PEM: {:?}", e)))?; - - // 解析 CRL - let (_, crl) = CertificateRevocationList::from_der(&pem.contents) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL: {:?}", e)))?; - - // 检查 CRL 的有效期 - let now = chrono::Utc::now(); - let crl_not_before = chrono::DateTime::from_timestamp( - crl.tbs_cert_list.this_update.timestamp(), - 0 - ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL this_update date"))?; - - let crl_not_after = if let Some(next_update) = crl.tbs_cert_list.next_update { - chrono::DateTime::from_timestamp( - next_update.timestamp(), - 0 - ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL next_update date"))? - } else { - crl_not_before + chrono::Duration::days(365) - }; - - // 检查 CRL 是否在有效期内 - if now < crl_not_before { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("CRL is not yet valid. Valid from: {}", crl_not_before) - ).into()); - } - - if now > crl_not_after { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("CRL has expired. Expired at: {}", crl_not_after) - ).into()); - } - - // 获取证书序列号 - let cert_serial = cert.tbs_certificate.serial.to_bigint() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert certificate serial to BigInt"))?; - - // 检查 CRL 吊销条目 - for revoked in crl.iter_revoked_certificates() { - let revoked_serial = revoked.user_certificate.to_bigint() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert revoked certificate serial to BigInt"))?; - - if revoked_serial == cert_serial { - return Ok(true); - } - } - Ok(false) -} - -enum DynoClient { - Secure(StreamOwned), - Insecure(TcpStream), -} - -fn create_dyno_client( - host: &str, - port: u16, - certs_dir: &str, -) -> Result { - if certs_dir == "NO_CERTS" { - println!("Running in no-certificate mode"); - create_dyno_client_with_no_certs(host, port) - } else { - println!("Running in certificate mode"); - let certs_dir = PathBuf::from(certs_dir); - let config = ClientConfigPath { - cert_path: certs_dir.join("client.crt"), - key_path: certs_dir.join("client.key"), - ca_cert_path: certs_dir.join("ca.crt"), - }; - let client = create_dyno_client_with_certs(host, port, &config)?; - Ok(DynoClient::Secure(client)) - } -} - -fn create_dyno_client_with_no_certs( - host: &str, - port: u16, -) -> Result { +/// Create a socket connection to dynolog +fn create_dyno_client(host: &str, port: u16) -> Result { let addr = (host, port) .to_socket_addrs()? .next() .expect("Failed to connect to the server"); - let stream = TcpStream::connect(addr)?; - Ok(DynoClient::Insecure(stream)) -} - -fn create_dyno_client_with_certs( - host: &str, - port: u16, - config: &ClientConfigPath, -) -> Result> { - let addr = (host, port) - .to_socket_addrs()? - .next() - .ok_or_else(|| io::Error::new( - io::ErrorKind::NotFound, - "Could not resolve the host address" - ))?; - - let stream = TcpStream::connect(addr)?; - - println!("Loading CA cert from: {}", config.ca_cert_path.display()); - let mut root_store = RootCertStore::empty(); - let ca_file = File::open(&config.ca_cert_path)?; - let mut ca_reader = BufReader::new(ca_file); - let ca_certs = rustls_pemfile::certs(&mut ca_reader)?; - for ca_cert in &ca_certs { - verify_certificate(ca_cert, true)?; // 验证根证书 - } - for ca_cert in ca_certs { - root_store.add(&Certificate(ca_cert))?; - } - - println!("Loading client cert from: {}", config.cert_path.display()); - let cert_file = File::open(&config.cert_path)?; - let mut cert_reader = BufReader::new(cert_file); - let certs = rustls_pemfile::certs(&mut cert_reader)?; - - // 检查客户端证书的基本要求 - for cert in &certs { - verify_certificate(cert, false)?; // 验证客户端证书 - } - - // 检查证书吊销状态 - let crl_path = config.cert_path.parent().unwrap().join("ca.crl"); - if crl_path.exists() { - println!("Checking CRL file: {}", crl_path.display()); - for cert in &certs { - match is_cert_revoked(cert, &crl_path) { - Ok(true) => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Certificate is revoked" - ).into()); - } - Ok(false) => { - continue; - } - Err(e) => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("CRL verification failed: {}", e) - ).into()); - } - } - } - } else { - println!("CRL file does not exist: {}", crl_path.display()); - } - let certs = certs.into_iter().map(Certificate).collect(); - - println!("Loading client key from: {}", config.key_path.display()); - let key_file = File::open(&config.key_path)?; - let mut key_reader = BufReader::new(key_file); - - // 检查私钥是否加密 - let mut key_data = Vec::new(); - key_reader.read_to_end(&mut key_data)?; - let key_str = String::from_utf8_lossy(&key_data); - let is_encrypted = key_str.contains("ENCRYPTED"); - - // 根据是否加密来加载私钥 - let keys = if is_encrypted { - // 如果私钥是加密的,请求用户输入密码 - let mut password = prompt_password("Please enter the certificate password: ")?; - let pkey = PKey::private_key_from_pem_passphrase(&key_data, password.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to decrypt private key: {}", e)))?; - - // 清除密码 - password.clear(); - - // 返回私钥 - vec![pkey.private_key_to_der() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to convert private key to DER: {}", e)))?] - } else { - // 如果私钥未加密,直接加载 - let mut key_reader = BufReader::new(File::open(&config.key_path)?); - rustls_pemfile::pkcs8_private_keys(&mut key_reader)? - }; - - if keys.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "No private key found in the key file" - ).into()); - } - let key = PrivateKey(keys[0].clone()); - - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_client_auth_cert(certs, key)?; - - let server_name = rustls::ServerName::try_from(host) - .map_err(|e| io::Error::new( - io::ErrorKind::InvalidInput, - format!("Invalid hostname: {}", e) - ))?; - - let conn = rustls::ClientConnection::new( - Arc::new(config), - server_name - )?; - - Ok(StreamOwned::new(conn, stream)) + TcpStream::connect(addr).map_err(|err| err.into()) } - fn main() -> Result<()> { let Opts { hostname, port, - certs_dir, cmd, } = Opts::parse(); - let client = create_dyno_client(&hostname, port, &certs_dir) - .expect("Couldn't connect to the server..."); + let dyno_client = + create_dyno_client(&hostname, port).expect("Couldn't connect to the server..."); match cmd { - Command::Status => status::run_status(client), - Command::Version => version::run_version(client), + Command::Status => status::run_status(dyno_client), + Command::Version => version::run_version(dyno_client), Command::Gputrace { job_id, pids, @@ -693,95 +163,10 @@ fn main() -> Result<()> { trigger_config, trace_options, }; - gputrace::run_gputrace(client, job_id, &pids, process_limit, trace_config) + gputrace::run_gputrace(dyno_client, job_id, &pids, process_limit, trace_config) } - Command::Nputrace { - job_id, - pids, - log_file, - duration_ms, - iterations, - profile_start_time, - start_step, - process_limit, - record_shapes, - profile_memory, - with_stack, - with_flops, - with_modules, - activities, - analyse, - profiler_level, - aic_metrics, - l2_cache, - op_attr, - msprof_tx, - gc_detect_threshold, - data_simplification, - export_type, - host_sys, - sys_io, - sys_interconnection, - mstx_domain_include, - mstx_domain_exclude, - } => { - let trigger_config = if iterations > 0 { - NpuTraceTriggerConfig::IterationBased { - start_step, - iterations, - } - } else { - NpuTraceTriggerConfig::DurationBased { - profile_start_time, - duration_ms, - } - }; - - let trace_options = NpuTraceOptions { - record_shapes, - profile_memory, - with_stack, - with_flops, - with_modules, - activities, - analyse, - profiler_level, - aic_metrics, - l2_cache, - op_attr, - msprof_tx, - gc_detect_threshold, - data_simplification, - export_type, - host_sys, - sys_io, - sys_interconnection, - mstx_domain_include, - mstx_domain_exclude, - }; - let trace_config = NpuTraceConfig { - log_file, - trigger_config, - trace_options, - }; - nputrace::run_nputrace(client, job_id, &pids, process_limit, trace_config) - } - Command::NpuMonitor { - npu_monitor_start, - npu_monitor_stop, - report_interval_s, - mspti_activity_kind, - } => { - let npu_mon_config = NpuMonitorConfig { - npu_monitor_start, - npu_monitor_stop, - report_interval_s, - mspti_activity_kind - }; - npumonitor::run_npumonitor(client, npu_mon_config) - } - Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(client, duration_s), - Command::DcgmResume => dcgm::run_dcgm_resume(client), + Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(dyno_client, duration_s), + Command::DcgmResume => dcgm::run_dcgm_resume(dyno_client), // ... add new commands here } -} \ No newline at end of file +} diff --git a/msmonitor/dynolog_npu/dynolog/src/Main.cpp b/msmonitor/dynolog_npu/dynolog/src/Main.cpp index a24bb802ab4..41f224fc905 100644 --- a/msmonitor/dynolog_npu/dynolog/src/Main.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/Main.cpp @@ -15,7 +15,6 @@ #include "dynolog/src/KernelCollector.h" #include "dynolog/src/Logger.h" #include "dynolog/src/ODSJsonLogger.h" - #include "dynolog/src/PerfMonitor.h" #include "dynolog/src/ScubaLogger.h" #include "dynolog/src/ServiceHandler.h" @@ -29,10 +28,6 @@ #include "dynolog/src/PrometheusLogger.h" #endif -#ifdef USE_TENSORBOARD -#include "dynolog/src/DynologTensorBoardLogger.h" -#endif - using namespace dynolog; using json = nlohmann::json; namespace hbt = facebook::hbt; @@ -67,47 +62,39 @@ DEFINE_bool( "Enabled GPU monitorng, currently supports NVIDIA GPUs."); DEFINE_bool(enable_perf_monitor, false, "Enable heartbeat perf monitoring."); -std::unique_ptr getLogger(const std::string& scribe_category = "") -{ - std::vector> loggers; +std::unique_ptr getLogger(const std::string& scribe_category = "") { + std::vector> loggers; #ifdef USE_PROMETHEUS - if (FLAGS_use_prometheus) { + if (FLAGS_use_prometheus) { loggers.push_back(std::make_unique()); - } + } #endif -#ifdef USE_TENSORBOARD - if (!FLAGS_metric_log_dir.empty()) { - loggers.push_back(std::make_unique(FLAGS_metric_log_dir)); - } -#endif - if (FLAGS_use_fbrelay) { + if (FLAGS_use_fbrelay) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_ODS) { + } + if (FLAGS_use_ODS) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_JSON) { + } + if (FLAGS_use_JSON) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_scuba && !scribe_category.empty()) { + } + if (FLAGS_use_scuba && !scribe_category.empty()) { loggers.push_back(std::make_unique(scribe_category)); - } - return std::make_unique(std::move(loggers)); + } + return std::make_unique(std::move(loggers)); } -auto next_wakeup(int sec) -{ - return std::chrono::steady_clock::now() + std::chrono::seconds(sec); +auto next_wakeup(int sec) { + return std::chrono::steady_clock::now() + std::chrono::seconds(sec); } -void kernel_monitor_loop() -{ - KernelCollector kc; +void kernel_monitor_loop() { + KernelCollector kc; - LOG(INFO) << "Running kernel monitor loop : interval = " + LOG(INFO) << "Running kernel monitor loop : interval = " << FLAGS_kernel_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_kernel_monitor_reporting_interval_s); @@ -118,21 +105,20 @@ void kernel_monitor_loop() /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -void perf_monitor_loop() -{ - PerfMonitor pm( +void perf_monitor_loop() { + PerfMonitor pm( hbt::CpuSet::makeAllOnline(), std::vector{"instructions", "cycles"}, getDefaultPmuDeviceManager(), getDefaultMetrics()); - LOG(INFO) << "Running perf monitor loop : interval = " - << FLAGS_perf_monitor_reporting_interval_s << " s."; + LOG(INFO) << "Running perf monitor loop : interval = " + << FLAGS_perf_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_perf_monitor_reporting_interval_s); @@ -143,24 +129,22 @@ void perf_monitor_loop() logger->finalize(); /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -auto setup_server(std::shared_ptr handler) -{ - return std::make_unique>( - handler, FLAGS_port); +auto setup_server(std::shared_ptr handler) { + return std::make_unique>( + handler, FLAGS_port); } -void gpu_monitor_loop(std::shared_ptr dcgm) -{ - auto logger = getLogger(FLAGS_scribe_category); +void gpu_monitor_loop(std::shared_ptr dcgm) { + auto logger = getLogger(FLAGS_scribe_category); - LOG(INFO) << "Running DCGM loop : interval = " - << FLAGS_dcgm_reporting_interval_s << " s."; - LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; + LOG(INFO) << "Running DCGM loop : interval = " + << FLAGS_dcgm_reporting_interval_s << " s."; + LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; - while (1) { + while (1) { auto wakeup_timepoint = next_wakeup(FLAGS_dcgm_reporting_interval_s); dcgm->update(); @@ -168,65 +152,55 @@ void gpu_monitor_loop(std::shared_ptr dcgm) /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -int main(int argc, char** argv) -{ - gflags::ParseCommandLineFlags(&argc, &argv, true); - FLAGS_logtostderr = 1; - google::InitGoogleLogging(argv[0]); +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + FLAGS_logtostderr = 1; + google::InitGoogleLogging(argv[0]); - LOG(INFO) << "Starting Ascend Extension for dynolog, version = " DYNOLOG_VERSION - << ", build git-hash = " DYNOLOG_GIT_REV; + LOG(INFO) << "Starting dynolog, version = " DYNOLOG_VERSION + << ", build git-hash = " DYNOLOG_GIT_REV; - std::shared_ptr dcgm; + std::shared_ptr dcgm; - std::unique_ptr ipcmon; - std::unique_ptr ipcmon_thread; - std::unique_ptr data_ipcmon_thread; - std::unique_ptr gpumon_thread; - std::unique_ptr pm_thread; + std::unique_ptr ipcmon; + std::unique_ptr ipcmon_thread, gpumon_thread, pm_thread; - if (FLAGS_enable_ipc_monitor) { + if (FLAGS_enable_ipc_monitor) { LOG(INFO) << "Starting IPC Monitor"; ipcmon = std::make_unique(); - ipcmon->setLogger(std::move(getLogger())); ipcmon_thread = std::make_unique([&ipcmon]() { ipcmon->loop(); }); - data_ipcmon_thread = - std::make_unique([&ipcmon]() { ipcmon->dataLoop(); }); - } + } - if (FLAGS_enable_gpu_monitor) { + if (FLAGS_enable_gpu_monitor) { dcgm = gpumon::DcgmGroupInfo::factory( gpumon::FLAGS_dcgm_fields, FLAGS_dcgm_reporting_interval_s * 1000); gpumon_thread = std::make_unique(gpu_monitor_loop, dcgm); - } - std::thread km_thread{kernel_monitor_loop}; - if (FLAGS_enable_perf_monitor) { + } + std::thread km_thread{kernel_monitor_loop}; + if (FLAGS_enable_perf_monitor) { pm_thread = std::make_unique(perf_monitor_loop); - } - - // setup service - auto handler = std::make_shared(dcgm); + } - // use simple json RPC server for now - auto server = setup_server(handler); - server->run(); + // setup service + auto handler = std::make_shared(dcgm); - if (km_thread.joinable()) { - km_thread.join(); - } + // use simple json RPC server for now + auto server = setup_server(handler); + server->run(); - if (pm_thread && pm_thread->joinable()) { + km_thread.join(); + if (pm_thread) { pm_thread->join(); - } - if (gpumon_thread && gpumon_thread->joinable()) { + } + if (gpumon_thread) { gpumon_thread->join(); - } + } - server->stop(); + server->stop(); - return 0; + return 0; } diff --git a/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp index 959bb1778ce..dd3327b44b5 100644 --- a/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp @@ -10,30 +10,90 @@ namespace dynolog { -const std::vector getAllMetrics() -{ - static std::vector metrics_ = { - {.name = "kindName", - .type = MetricType::Instant, - .desc = "Report data kind name"}, - {.name = "duration", - .type = MetricType::Delta, - .desc = "Total execution time for corresponding kind"}, - {.name = "timestamp", - .type = MetricType::Instant, - .desc = "The timestamp of the reported data"}, - {.name = "deviceId", - .type = MetricType::Instant, - .desc = "The ID of the device for reporting data"}, - }; - return metrics_; +const std::vector getAllMetrics() { + static std::vector metrics_ = { + {.name = "cpu_util", + .type = MetricType::Ratio, + .desc = "Fraction of total CPU time spend on user or system mode."}, + {.name = "cpu_u", + .type = MetricType::Ratio, + .desc = "Fraction of total CPU time spent in user mode."}, + {.name = "cpu_s", + .type = MetricType::Ratio, + .desc = "Fraction of total CPU time spent in system mode."}, + {.name = "cpu_i", + .type = MetricType::Ratio, + .desc = "Fraction of total CPU time spent in idle mode."}, + {.name = "mips", + .type = MetricType::Ratio, + .desc = "Number of million instructions executed per second."}, + {.name = "mega_cycles_per_second", + .type = MetricType::Ratio, + .desc = "Number of active CPU clock cycles per second."}, + {.name = "uptime", + .type = MetricType::Instant, + .desc = "How long the system has been running in seconds."}, + }; + static std::map cpustats = { + {"cpu_u_ms", "user"}, + {"cpu_s_ms", "system"}, + {"cpu_n_ms", "nice"}, + {"cpu_w_ms", "iowait"}, + {"cpu_x_ms", "irq"}, + {"cpu_y_ms", "softirq"}, + }; + + auto metrics = metrics_; + + for (const auto& [name, cpu_mode] : cpustats) { + MetricDesc m{ + .name = name, + .type = MetricType::Delta, + .desc = fmt::format( + "Total CPU time in milliseconds spent in {} mode. " + "For more details please see man page for /proc/stat", + cpu_mode)}; + metrics.push_back(m); + } + + return metrics; } // These metrics are dynamic per network drive -const std::vector getNetworkMetrics() -{ - static std::vector metrics_ = {}; - return metrics_; +const std::vector getNetworkMetrics() { + static std::vector metrics_ = { + {.name = "tx_bytes", + .type = MetricType::Delta, + .desc = + "Total bytes transmitted/received over the specific network device."}, + {.name = "rx_bytes", + .type = MetricType::Delta, + .desc = + "Total bytes transmitted/received over the specific network device."}, + {.name = "tx_packets", + .type = MetricType::Delta, + .desc = + "Total packets transmitted/received over the specific network device."}, + {.name = "rx_packets", + .type = MetricType::Delta, + .desc = + "Total packets transmitted/received over the specific network device."}, + {.name = "tx_errors", + .type = MetricType::Delta, + .desc = "Total transmit/receive errors on the specific network device."}, + {.name = "rx_errors", + .type = MetricType::Delta, + .desc = "Total transmit/receive errors on the specific network device."}, + {.name = "tx_drops", + .type = MetricType::Delta, + .desc = + "Total transmit/receive packet drops on the specific network device."}, + {.name = "rx_drops", + .type = MetricType::Delta, + .desc = + "Total transmit/receive packet drops on the specific network device."}, + }; + return metrics_; } -} // namespace dynolog \ No newline at end of file +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp index dfe980a4013..579539f754e 100644 --- a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp @@ -10,107 +10,76 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -DEFINE_string(certs_dir, "", "TLS crets dir"); constexpr int CLIENT_QUEUE_LEN = 50; -const std::string NO_CERTS_MODE = "NO_CERTS"; -const size_t MIN_RSA_KEY_LENGTH = 3072; namespace dynolog { -void secure_clear_password(std::string& password); - -SimpleJsonServerBase::SimpleJsonServerBase(int port) : port_(port) -{ - try { - initSocket(); - if (FLAGS_certs_dir != NO_CERTS_MODE) { - init_openssl(); - ctx_ = create_context(); - configure_context(ctx_); - } - } catch (const std::exception& e) { - LOG(ERROR) << "Failed to initialize server: " << e.what(); - throw; - } -} - -SimpleJsonServerBase::~SimpleJsonServerBase() -{ - if (thread_) { - stop(); - } +SimpleJsonServerBase::SimpleJsonServerBase(int port) : port_(port) { + initSocket(); +} + +SimpleJsonServerBase::~SimpleJsonServerBase() { + if (thread_) { + stop(); + } + close(sock_fd_); +} + +void SimpleJsonServerBase::initSocket() { + struct sockaddr_in6 server_addr; + + /* Create socket for listening (client requests).*/ + sock_fd_ = ::socket(AF_INET6, SOCK_STREAM, 0); + if (sock_fd_ == -1) { + std::perror("socket()"); + return; + } + + /* Set socket to reuse address in case server is restarted.*/ + int flag = 1; + int ret = + ::setsockopt(sock_fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)); + if (ret == -1) { + std::perror("setsockopt()"); + return; + } + + // in6addr_any allows us to bind to both IPv4 and IPv6 clients. + server_addr.sin6_addr = in6addr_any; + server_addr.sin6_family = AF_INET6; + server_addr.sin6_port = htons(port_); + + /* Bind address and socket together */ + ret = ::bind(sock_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (ret == -1) { + std::perror("bind()"); close(sock_fd_); - if (FLAGS_certs_dir != NO_CERTS_MODE && ctx_) { - SSL_CTX_free(ctx_); - } -} - -void SimpleJsonServerBase::initSocket() -{ - struct sockaddr_in6 server_addr; - - /* Create socket for listening (client requests). */ - sock_fd_ = ::socket(AF_INET6, SOCK_STREAM, 0); - if (sock_fd_ == -1) { - std::perror("socket()"); - return; - } - - /* Set socket to reuse address in case server is restarted. */ - int flag = 1; - int ret = - ::setsockopt(sock_fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)); - if (ret == -1) { - std::perror("setsockopt()"); - return; - } - - // in6addr_any allows us to bind to both IPv4 and IPv6 clients. - server_addr.sin6_addr = in6addr_any; - server_addr.sin6_family = AF_INET6; - server_addr.sin6_port = htons(port_); + return; + } - /* Bind address and socket together */ - ret = ::bind(sock_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); - if (ret == -1) { - std::perror("bind()"); - close(sock_fd_); - return; - } - - /* Create listening queue (client requests) */ - ret = ::listen(sock_fd_, CLIENT_QUEUE_LEN); - if (ret == -1) { - std::perror("listen()"); - close(sock_fd_); - return; - } - - /* Get port if assigned 0 */ - if (port_ == 0) { - socklen_t len_out = sizeof(server_addr); - ret = ::getsockname(sock_fd_, (struct sockaddr*)&server_addr, &len_out); - if (ret < 0 || len_out != sizeof(server_addr)) { - std::perror("getsockname()"); - } else { - port_ = ntohs(server_addr.sin6_port); - LOG(INFO) << "System assigned port = " << ntohs(server_addr.sin6_port); - } + /* Create listening queue (client requests) */ + ret = ::listen(sock_fd_, CLIENT_QUEUE_LEN); + if (ret == -1) { + std::perror("listen()"); + close(sock_fd_); + return; + } + + /* Get port if assigned 0 */ + if (port_ == 0) { + socklen_t len_out = sizeof(server_addr); + ret = ::getsockname(sock_fd_, (struct sockaddr*)&server_addr, &len_out); + if (ret < 0 || len_out != sizeof(server_addr)) { + std::perror("getsockname()"); + } else { + port_ = ntohs(server_addr.sin6_port); + LOG(INFO) << "System assigned port = " << ntohs(server_addr.sin6_port); } + } - LOG(INFO) << "Listening to connections on port " << port_; - initSuccess_ = true; + LOG(INFO) << "Listening to connections on port " << port_; + initSuccess_ = true; } /* A simple wrapper to accept connections and read data @@ -121,655 +90,141 @@ void SimpleJsonServerBase::initSocket() * : char json[] */ class ClientSocketWrapper { -public: - ~ClientSocketWrapper() - { - if (FLAGS_certs_dir != NO_CERTS_MODE && ssl_) { - SSL_shutdown(ssl_); - SSL_free(ssl_); - } - if (client_sock_fd_ != -1) { - ::close(client_sock_fd_); - } + public: + ~ClientSocketWrapper() { + if (::close(client_sock_fd_) < 0) { + std::perror("close()"); + } + } + + bool accept(int server_socket) { + struct sockaddr_in6 client_addr; + socklen_t client_addr_len = sizeof(client_addr); + std::array client_addr_str; + + client_sock_fd_ = ::accept( + server_socket, (struct sockaddr*)&client_addr, &client_addr_len); + if (client_sock_fd_ == -1) { + std::perror("accept()"); + return false; + } + + inet_ntop( + AF_INET6, + &(client_addr.sin6_addr), + client_addr_str.data(), + client_addr_str.size()); + LOG(INFO) << "Received connection from " << client_addr_str.data(); + return true; + } + + /* Reads an entire message from the client, + * expected to return json*/ + std::string get_message() { + /* Wait for data from client */ + int32_t msg_size = -1; + // TODO use decltype or type alias + if (!read_helper((uint8_t*)&msg_size, sizeof(msg_size)) || msg_size <= 0) { + LOG(ERROR) << "Invalid message size = " << msg_size; + return ""; + } + + std::string message; + message.resize(msg_size); + int recv = 0; + + /* Wait for data from client */ + int ret = 1; + while (recv < msg_size && ret > 0) { + ret = read_helper((uint8_t*)&message[recv], msg_size - recv); + recv += ret > 0 ? ret : 0; + } + + if (recv != msg_size) { + LOG(ERROR) << "Received parital message, expected size " << msg_size + << "found : " << recv; + LOG(ERROR) << "Message received = " << message; + return ""; + } + + return message; + } + + bool send_response(const std::string& response) { + // send size prefixed response + int32_t size = response.size(); + int ret = ::write(client_sock_fd_, (void*)&size, sizeof(size)); + if (ret == -1) { + std::perror("read()"); } - bool accept(int server_socket, SSL_CTX* ctx) - { - struct sockaddr_in6 client_addr; - socklen_t client_addr_len = sizeof(client_addr); - std::array client_addr_str; - - client_sock_fd_ = ::accept( - server_socket, (struct sockaddr*)&client_addr, &client_addr_len); - if (client_sock_fd_ == -1) { - std::perror("accept()"); - return false; - } - - inet_ntop( - AF_INET6, - &(client_addr.sin6_addr), - client_addr_str.data(), - client_addr_str.size()); - LOG(INFO) << "Received connection from " << client_addr_str.data(); - - if (FLAGS_certs_dir == NO_CERTS_MODE) { - LOG(INFO) << "No certs mode"; - return true; - } - - ssl_ = SSL_new(ctx); - SSL_set_fd(ssl_, client_sock_fd_); - if (SSL_accept(ssl_) <= 0) { - ERR_print_errors_fp(stderr); - return false; - } - LOG(INFO) << "SSL handshake success"; - return true; + int sent = 0; + while (sent < size && ret > 0) { + ret = ::write(client_sock_fd_, (void*)&response[sent], size - sent); + if (ret == -1) { + std::perror("read()"); + } else { + sent += ret; + } } - std::string get_message() - { - int32_t msg_size = -1; - if (!read_helper((uint8_t*)&msg_size, sizeof(msg_size)) || msg_size <= 0) { - LOG(ERROR) << "Invalid message size = " << msg_size; - return ""; - } - - std::string message; - message.resize(msg_size); - int recv = 0; - int ret = 1; - while (recv < msg_size && ret > 0) { - ret = read_helper((uint8_t*)&message[recv], msg_size - recv); - recv += ret > 0 ? ret : 0; - } - - if (recv != msg_size) { - LOG(ERROR) << "Received partial message, expected size " << msg_size - << " found : " << recv; - LOG(ERROR) << "Message received = " << message; - return ""; - } - - return message; + if (sent < response.size()) { + LOG(ERROR) << "Unable to write full response"; + return false; } + return ret > 0; + } - bool send_response(const std::string& response) - { - int32_t size = response.size(); - int ret; - if (FLAGS_certs_dir == NO_CERTS_MODE) { - ret = ::write(client_sock_fd_, (void*)&size, sizeof(size)); - if (ret == -1) { - std::perror("write()"); - return false; - } - } else { - ret = SSL_write(ssl_, (void*)&size, sizeof(size)); - if (ret <= 0) { - ERR_print_errors_fp(stderr); - return false; - } - } - int sent = 0; - while (sent < size && ret > 0) { - if (FLAGS_certs_dir == NO_CERTS_MODE) { - ret = ::write(client_sock_fd_, (void*)&response[sent], size - sent); - if (ret == -1) { - std::perror("write()"); - } else { - sent += ret; - } - } else { - ret = SSL_write(ssl_, (void*)&response[sent], size - sent); - if (ret <= 0) { - ERR_print_errors_fp(stderr); - } else { - sent += ret; - } - } - } - - if (sent < response.size()) { - LOG(ERROR) << "Unable to write full response"; - return false; - } - return ret > 0; + private: + int read_helper(uint8_t* buf, int size) { + int ret = ::read(client_sock_fd_, (void*)buf, size); + if (ret == -1) { + std::perror("read()"); } + return ret; + } -private: - int read_helper(uint8_t* buf, int size) - { - if (FLAGS_certs_dir == NO_CERTS_MODE) { - int ret = ::read(client_sock_fd_, (void*)buf, size); - if (ret == -1) { - std::perror("read()"); - } - return ret; - } - int ret = SSL_read(ssl_, (void*)buf, size); - if (ret <= 0) { - ERR_print_errors_fp(stderr); - } - return ret; - } - int client_sock_fd_ = -1; - SSL* ssl_ = nullptr; + int client_sock_fd_ = -1; }; /* Accepts socket connections and processes the payloads. - * This will inturn call the Handler functions */ -void SimpleJsonServerBase::loop() noexcept -{ - if (sock_fd_ == -1 || !initSuccess_) { - return; - } - - while (run_) { - processOne(); - } -} - -void SimpleJsonServerBase::processOne() noexcept -{ - LOG(INFO) << "Waiting for connection."; - ClientSocketWrapper client; - if (!client.accept(sock_fd_, ctx_)) { - return; - } - std::string request_str = client.get_message(); - LOG(INFO) << "RPC message received = " << request_str; - auto response_str = processOneImpl(request_str); - if (response_str.empty()) { - return; - } - if (!client.send_response(response_str)) { - LOG(ERROR) << "Failed to send response"; - } -} - -void SimpleJsonServerBase::run() -{ - LOG(INFO) << "Launching RPC thread"; - thread_ = std::make_unique([this]() { this->loop(); }); -} - -void SimpleJsonServerBase::init_openssl() -{ - SSL_load_error_strings(); - OpenSSL_add_ssl_algorithms(); -} - -SSL_CTX* SimpleJsonServerBase::create_context() -{ - const SSL_METHOD* method = TLS_server_method(); - SSL_CTX* ctx = SSL_CTX_new(method); - if (!ctx) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Unable to create SSL context"); - } - return ctx; -} - -static bool is_cert_revoked(X509* cert, X509_STORE* store) -{ - if (!cert || !store) { - LOG(ERROR) << "Invalid certificate or store pointer"; - return false; - } - // 获取证书的颁发者名称 - X509_NAME* issuer = X509_get_issuer_name(cert); - if (!issuer) { - LOG(ERROR) << "Failed to get certificate issuer"; - return false; - } - // 获取证书的序列号 - const ASN1_INTEGER* serial = X509_get_serialNumber(cert); - if (!serial) { - LOG(ERROR) << "Failed to get certificate serial number"; - return false; - } - // 创建证书验证上下文 - X509_STORE_CTX* ctx = X509_STORE_CTX_new(); - if (!ctx) { - LOG(ERROR) << "Failed to create certificate store context"; - return false; - } - bool is_revoked = false; - try { - // 初始化证书验证上下文 - if (!X509_STORE_CTX_init(ctx, store, cert, nullptr)) { - LOG(ERROR) << "Failed to initialize certificate store context"; - X509_STORE_CTX_free(ctx); - return false; - } - // 获取CRL列表 - STACK_OF(X509_CRL)* crls = X509_STORE_CTX_get1_crls(ctx, issuer); - if (!crls) { - LOG(INFO) << "No CRLs found for issuer"; - X509_STORE_CTX_free(ctx); - return false; - } - time_t current_time = time(nullptr); - for (int i = 0; i < sk_X509_CRL_num(crls); i++) { - X509_CRL* crl = sk_X509_CRL_value(crls, i); - if (!crl) { - LOG(ERROR) << "Invalid CRL at index " << i; - continue; - } - // 检查 CRL 的有效期 - const ASN1_TIME* crl_this_update = X509_CRL_get0_lastUpdate(crl); - const ASN1_TIME* crl_next_update = X509_CRL_get0_nextUpdate(crl); - if (!crl_this_update) { - LOG(ERROR) << "Failed to get CRL this_update time"; - continue; - } - // 检查 CRL 是否已生效 - if (X509_cmp_time(crl_this_update, ¤t_time) > 0) { - LOG(INFO) << "CRL is not yet valid"; - continue; - } - // 检查 CRL 是否过期 - if (crl_next_update) { - if (X509_cmp_time(crl_next_update, ¤t_time) < 0) { - LOG(INFO) << "CRL has expired"; - continue; - } - } - // 检查证书是否在 CRL 中 - STACK_OF(X509_REVOKED)* revoked = X509_CRL_get_REVOKED(crl); - if (revoked) { - for (int j = 0; j < sk_X509_REVOKED_num(revoked); j++) { - X509_REVOKED* rev = sk_X509_REVOKED_value(revoked, j); - if (rev) { - const ASN1_INTEGER* rev_serial = X509_REVOKED_get0_serialNumber(rev); - if (rev_serial && ASN1_INTEGER_cmp(serial, rev_serial) == 0) { - LOG(INFO) << "Certificate is found in CRL"; - is_revoked = true; - break; - } - } - } - } - if (is_revoked) { - break; - } - } - if (crls) { - sk_X509_CRL_pop_free(crls, X509_CRL_free); - } - } catch (const std::exception& e) { - LOG(ERROR) << "Exception while checking CRL: " << e.what(); - is_revoked = false; - } - X509_STORE_CTX_free(ctx); - return is_revoked; -} - - -// 禁用终端回显的函数,但显示星号 -std::string get_password_with_stars() -{ - struct termios old_flags; - struct termios new_flags; - std::string password; - - // 获取当前终端设置 - tcgetattr(fileno(stdin), &old_flags); - new_flags = old_flags; - new_flags.c_lflag &= ~ECHO; // 禁用回显 - - // 设置新的终端属性 - tcsetattr(fileno(stdin), TCSANOW, &new_flags); - - // 读取密码并显示星号 - char ch; - while ((ch = getchar()) != '\n') { - if (ch == 127 || ch == 8) { // 处理退格键 (ASCII 127 或 8) - if (!password.empty()) { - password.pop_back(); - std::cout << "\b \b"; // 删除一个星号 - std::cout.flush(); // 立即刷新输出 - } - } else { - password += ch; - std::cout << '*'; // 显示星号 - std::cout.flush(); // 立即刷新输出 - } - } - - // 恢复原来的终端设置 - tcsetattr(fileno(stdin), TCSANOW, &old_flags); + * This will inturn call the Handler functions*/ +void SimpleJsonServerBase::loop() noexcept { + if (sock_fd_ == -1 || !initSuccess_) { + return; + } - return password; + while (run_) { + processOne(); + } } -// 验证证书版本和签名算法 -void SimpleJsonServerBase::verify_certificate_version_and_algorithm(X509* cert) -{ - // 1. 检查证书版本是否为 X.509v3 - if (X509_get_version(cert) != 2) { // 2 表示 X.509v3 - throw std::runtime_error("Certificate is not X.509v3"); - } - - // 2. 检查证书签名算法 - const X509_ALGOR* sig_alg = X509_get0_tbs_sigalg(cert); - if (!sig_alg) { - throw std::runtime_error("Failed to get signature algorithm"); - } - - int sig_nid = OBJ_obj2nid(sig_alg->algorithm); - // 检查是否使用不安全的算法 - if (sig_nid == NID_md2WithRSAEncryption || - sig_nid == NID_md5WithRSAEncryption || - sig_nid == NID_sha1WithRSAEncryption) { - throw std::runtime_error("Certificate uses insecure signature algorithm: " + std::to_string(sig_nid)); - } -} - -// 验证 RSA 密钥长度 -void SimpleJsonServerBase::verify_rsa_key_length(EVP_PKEY* pkey) -{ - if (EVP_PKEY_base_id(pkey) == EVP_PKEY_RSA) { - size_t key_length = 0; -#if OPENSSL_VERSION_NUMBER >= 0x30000000L - // OpenSSL 3.0 及以上版本 - key_length = EVP_PKEY_get_size(pkey) * 8; // 转换为位数 -#else - // OpenSSL 1.1.1 及以下版本 - RSA* rsa = EVP_PKEY_get0_RSA(pkey); - if (!rsa) { - throw std::runtime_error("Failed to get RSA key"); - } - - const BIGNUM* n = nullptr; - RSA_get0_key(rsa, &n, nullptr, nullptr); - if (!n) { - throw std::runtime_error("Failed to get RSA modulus"); - } - - key_length = BN_num_bits(n); -#endif - if (key_length < MIN_RSA_KEY_LENGTH) { - throw std::runtime_error("RSA key length " + std::to_string(key_length) + " bits is less than required " + std::to_string(MIN_RSA_KEY_LENGTH) + " bits"); - } - } -} - -// 验证证书有效期 -void SimpleJsonServerBase::verify_certificate_validity(X509* cert) -{ - ASN1_TIME* not_before = X509_get_notBefore(cert); - ASN1_TIME* not_after = X509_get_notAfter(cert); - if (!not_before || !not_after) { - throw std::runtime_error("Failed to get certificate validity period"); - } - - time_t current_time = time(nullptr); - struct tm tm_before = {}; - struct tm tm_after = {}; - if (!ASN1_TIME_to_tm(not_before, &tm_before) || - !ASN1_TIME_to_tm(not_after, &tm_after)) { - throw std::runtime_error("Failed to convert certificate dates"); - } - - time_t not_before_time = mktime(&tm_before); - time_t not_after_time = mktime(&tm_after); - - // 检查证书是否已生效 - if (current_time < not_before_time) { - BIO* bio = BIO_new(BIO_s_mem()); - if (bio) { - ASN1_TIME_print(bio, not_before); - char* not_before_str = nullptr; - long len = BIO_get_mem_data(bio, ¬_before_str); - if (len > 0) { - std::string time_str(not_before_str, len); - BIO_free(bio); - throw std::runtime_error("Server certificate is not yet valid. Valid from: " + time_str); - } - BIO_free(bio); - } - throw std::runtime_error("Server certificate is not yet valid"); - } +void SimpleJsonServerBase::processOne() noexcept { + /* Wait for incoming connections, the accept call is blocking*/ + LOG(INFO) << "Waiting for connection."; + ClientSocketWrapper client; + if (!client.accept(sock_fd_)) { + return; + } - // 检查证书是否已过期 - if (current_time > not_after_time) { - BIO* bio = BIO_new(BIO_s_mem()); - if (bio) { - ASN1_TIME_print(bio, not_after); - char* not_after_str = nullptr; - long len = BIO_get_mem_data(bio, ¬_after_str); - if (len > 0) { - std::string time_str(not_after_str, len); - BIO_free(bio); - throw std::runtime_error("Server certificate has expired. Expired at: " + time_str); - } - BIO_free(bio); - } - throw std::runtime_error("Server certificate has expired"); - } -} + std::string request_str = client.get_message(); + LOG(INFO) << "RPC message received = " << request_str; -// 验证证书扩展域 -void SimpleJsonServerBase::verify_certificate_extensions(X509* cert) -{ - bool has_ca_constraint = false; - bool has_key_usage = false; - bool has_cert_sign = false; - bool has_crl_sign = false; - - const STACK_OF(X509_EXTENSION)* exts = X509_get0_extensions(cert); - if (exts) { - for (int i = 0; i < sk_X509_EXTENSION_num(exts); i++) { - X509_EXTENSION* ext = sk_X509_EXTENSION_value(exts, i); - ASN1_OBJECT* obj = X509_EXTENSION_get_object(ext); - - if (OBJ_obj2nid(obj) == NID_basic_constraints) { - BASIC_CONSTRAINTS* constraints = (BASIC_CONSTRAINTS*)X509V3_EXT_d2i(ext); - if (constraints) { - has_ca_constraint = constraints->ca; - BASIC_CONSTRAINTS_free(constraints); - } - } else if (OBJ_obj2nid(obj) == NID_key_usage) { - ASN1_BIT_STRING* usage = (ASN1_BIT_STRING*)X509V3_EXT_d2i(ext); - if (usage) { - has_key_usage = true; - has_cert_sign = (usage->data[0] & KU_KEY_CERT_SIGN) != 0; - has_crl_sign = (usage->data[0] & KU_CRL_SIGN) != 0; - ASN1_BIT_STRING_free(usage); - } - } - } - } + auto response_str = processOneImpl(request_str); + if (response_str.empty()) { + return; + } - if (has_ca_constraint) { - throw std::runtime_error("Client certificate should not have CA constraint"); - } - if (!has_key_usage) { - throw std::runtime_error("Client certificate must have key usage extension"); - } -} + if (!client.send_response(response_str)) { + LOG(ERROR) << "Failed to send response"; + } -// 加载私钥 -void SimpleJsonServerBase::load_private_key(SSL_CTX* ctx, const std::string& server_key) -{ - FILE* key_file = fopen(server_key.c_str(), "r"); - if (!key_file) { - throw std::runtime_error("Failed to open server key file"); - } - - bool is_encrypted = false; - char line[256]; - while (fgets(line, sizeof(line), key_file)) { - if (strstr(line, "ENCRYPTED")) { - is_encrypted = true; - break; - } - } - rewind(key_file); - - if (is_encrypted) { - std::string password; - std::cout << "Please enter the certificate password: "; - password = get_password_with_stars(); - std::cout << std::endl; - - EVP_PKEY* pkey = PEM_read_PrivateKey( - key_file, - nullptr, - [](char* buf, int size, int rwflag, void* userdata) -> int { - const std::string* password = static_cast(userdata); - if (password->size() > static_cast(size)) { - return 0; - } - password->copy(buf, password->size()); - return password->size(); - }, - const_cast(&password)); - - fclose(key_file); - secure_clear_password(password); - - if (!pkey) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Failed to load encrypted server private key"); - } - - if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) { - EVP_PKEY_free(pkey); - ERR_print_errors_fp(stderr); - throw std::runtime_error("Failed to use server private key"); - } - - EVP_PKEY_free(pkey); - } else { - fclose(key_file); - if (SSL_CTX_use_PrivateKey_file(ctx, server_key.c_str(), SSL_FILETYPE_PEM) <= 0) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Failed to load server private key"); - } - } + return; } -// 加载和验证 CRL -void SimpleJsonServerBase::load_and_verify_crl(SSL_CTX* ctx, const std::string& crl_file) -{ - X509_STORE* store = SSL_CTX_get_cert_store(ctx); - if (!store) { - throw std::runtime_error("Failed to get certificate store"); - } - - if (access(crl_file.c_str(), F_OK) != -1) { - FILE* crl_file_ptr = fopen(crl_file.c_str(), "r"); - if (!crl_file_ptr) { - LOG(WARNING) << "Failed to open CRL file: " << crl_file; - return; - } - - X509_CRL* crl = PEM_read_X509_CRL(crl_file_ptr, nullptr, nullptr, nullptr); - fclose(crl_file_ptr); - - if (!crl) { - LOG(WARNING) << "Failed to read CRL from file: " << crl_file; - return; - } - - if (X509_STORE_add_crl(store, crl) != 1) { - LOG(WARNING) << "Failed to add CRL to certificate store"; - X509_CRL_free(crl); - return; - } - - X509* cert = SSL_CTX_get0_certificate(ctx); - if (!cert) { - X509_CRL_free(crl); - throw std::runtime_error("Failed to get server certificate"); - } - - if (is_cert_revoked(cert, store)) { - X509_CRL_free(crl); - throw std::runtime_error("Server certificate is revoked"); - } - - X509_CRL_free(crl); - } -} - -void SimpleJsonServerBase::configure_context(SSL_CTX* ctx) -{ - if (FLAGS_certs_dir.empty()) { - throw std::runtime_error("--certs-dir must be specified!"); - } - - std::string certs_dir = FLAGS_certs_dir; - if (!certs_dir.empty() && certs_dir.back() != '/') - certs_dir += '/'; - - std::string server_cert = certs_dir + "server.crt"; - std::string server_key = certs_dir + "server.key"; - std::string ca_cert = certs_dir + "ca.crt"; - std::string crl_file = certs_dir + "ca.crl"; - - LOG(INFO) << "Loading server cert: " << server_cert; - LOG(INFO) << "Loading server key: " << server_key; - LOG(INFO) << "Loading CA cert: " << ca_cert; - - // 1. 加载并验证服务器证书 - if (SSL_CTX_use_certificate_file(ctx, server_cert.c_str(), SSL_FILETYPE_PEM) <= 0) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Failed to load server certificate"); - } - - X509* cert = SSL_CTX_get0_certificate(ctx); - if (!cert) { - throw std::runtime_error("Failed to get server certificate"); - } - - // 2. 验证证书版本和签名算法 - verify_certificate_version_and_algorithm(cert); - - // 3. 验证 RSA 密钥长度 - EVP_PKEY* pkey = X509_get_pubkey(cert); - if (!pkey) { - throw std::runtime_error("Failed to get public key"); - } - verify_rsa_key_length(pkey); - EVP_PKEY_free(pkey); - - // 4. 验证证书有效期 - verify_certificate_validity(cert); - - // 5. 验证证书扩展域 - verify_certificate_extensions(cert); - - // 6. 加载私钥 - load_private_key(ctx, server_key); - - // 7. 加载 CA 证书 - if (SSL_CTX_load_verify_locations(ctx, ca_cert.c_str(), NULL) <= 0) { - ERR_print_errors_fp(stderr); - throw std::runtime_error("Failed to load CA certificate"); - } - - // 8. 加载和验证 CRL - load_and_verify_crl(ctx, crl_file); - - // 9. 设置证书验证选项 - SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); -} - -void secure_clear_password(std::string& password) -{ - if (!password.empty()) { - // 使用随机数据覆盖密码 - std::generate(password.begin(), password.end(), std::rand); - // 清空字符串 - password.clear(); - // 收缩字符串容量,释放内存 - password.shrink_to_fit(); - } +void SimpleJsonServerBase::run() { + LOG(INFO) << "Launching RPC thread"; + thread_ = std::make_unique([this]() { this->loop(); }); } -} // namespace dynolog \ No newline at end of file +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h index 9efe0d086ca..52be6e4cf30 100644 --- a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h @@ -10,71 +10,54 @@ #include #include #include -#include -#include #include "dynolog/src/ServiceHandler.h" -DECLARE_string(certs_dir); - namespace dynolog { + // This is a simple service built using UNIX Sockets // with remote procedure calls implemented via JSON string. + class SimpleJsonServerBase { -public: - explicit SimpleJsonServerBase(int port); - virtual ~SimpleJsonServerBase(); - - int getPort() const - { - return port_; - } - - bool initSuccessful() const - { - return initSuccess_; - } - // spin up a new thread to process requets - void run(); - - void stop() - { - run_ = 0; - thread_->join(); - } - - // synchronously processes a request - void processOne() noexcept; - -protected: - void initSocket(); - void init_openssl(); - SSL_CTX *create_context(); - void configure_context(SSL_CTX *ctx); - - // process requests in a loop - void loop() noexcept; - - // implement processing of request using the handler - virtual std::string processOneImpl(const std::string &request_str) - { - return ""; - } - - void verify_certificate_version_and_algorithm(X509 *cert); - void verify_rsa_key_length(EVP_PKEY *pkey); - void verify_certificate_validity(X509 *cert); - void verify_certificate_extensions(X509 *cert); - void load_private_key(SSL_CTX *ctx, const std::string &server_key); - void load_and_verify_crl(SSL_CTX *ctx, const std::string &crl_file); - - int port_; - int sock_fd_{-1}; - bool initSuccess_{false}; - - std::atomic run_{true}; - std::unique_ptr thread_; - - SSL_CTX *ctx_{nullptr}; + public: + explicit SimpleJsonServerBase(int port); + + virtual ~SimpleJsonServerBase(); + + int getPort() const { + return port_; + } + + bool initSuccessful() const { + return initSuccess_; + } + // spin up a new thread to process requets + void run(); + + void stop() { + run_ = 0; + thread_->join(); + } + + // synchronously processes a request + void processOne() noexcept; + + protected: + void initSocket(); + + // process requests in a loop + void loop() noexcept; + + // implement processing of request using the handler + virtual std::string processOneImpl(const std::string& request_str) { + return ""; + } + + int port_; + int sock_fd_{-1}; + bool initSuccess_{false}; + + std::atomic run_{true}; + std::unique_ptr thread_; }; -} // namespace dynolog \ No newline at end of file +} // namespace dynolog diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp index 2f419d3796a..8e5f0fe1757 100644 --- a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp @@ -5,8 +5,7 @@ #include "dynolog/src/tracing/IPCMonitor.h" #include -#include -#include +#include #include #include #include @@ -21,169 +20,105 @@ namespace dynolog { namespace tracing { constexpr int kSleepUs = 10000; -constexpr int kDataMsgSleepUs = 1000; const std::string kLibkinetoRequest = "req"; const std::string kLibkinetoContext = "ctxt"; -const std::string kLibkinetoData = "data"; -IPCMonitor::IPCMonitor(const std::string& ipc_fabric_name) -{ - ipc_manager_ = FabricManager::factory(ipc_fabric_name); - data_ipc_manager_ = FabricManager::factory(ipc_fabric_name + "_data"); - // below ensures singleton exists - LOG(INFO) << "Kineto config manager : active processes = " - << LibkinetoConfigManager::getInstance()->processCount("0"); +IPCMonitor::IPCMonitor(const std::string& ipc_fabric_name) { + ipc_manager_ = FabricManager::factory(ipc_fabric_name); + // below ensures singleton exists + LOG(INFO) << "Kineto config manager : active processes = " + << LibkinetoConfigManager::getInstance()->processCount("0"); } -void IPCMonitor::loop() -{ - while (ipc_manager_) { - if (ipc_manager_->recv()) { - std::unique_ptr msg = ipc_manager_->retrieve_msg(); - processMsg(std::move(msg)); - } - /* sleep override */ - usleep(kSleepUs); +void IPCMonitor::loop() { + while (ipc_manager_) { + if (ipc_manager_->recv()) { + std::unique_ptr msg = ipc_manager_->retrieve_msg(); + processMsg(std::move(msg)); } + /* sleep override */ + usleep(kSleepUs); + } } -void IPCMonitor::dataLoop() -{ - while (data_ipc_manager_) { - if (data_ipc_manager_->recv()) { - std::unique_ptr msg = data_ipc_manager_->retrieve_msg(); - processDataMsg(std::move(msg)); - } - /* sleep override */ - usleep(kDataMsgSleepUs); - } -} - -void IPCMonitor::processMsg(std::unique_ptr msg) -{ - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - // sizeof(msg->metadata.type) = 32, well above the size of the constant - // strings we are comparing against. memcmp is safe - if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) - msg->metadata.type, - kLibkinetoContext.data(), - kLibkinetoContext.size()) == 0) { - registerLibkinetoContext(std::move(msg)); - } else if ( - memcmp( // NOLINT(facebook-security-vulnerable-memcmp) - msg->metadata.type, - kLibkinetoRequest.data(), - kLibkinetoRequest.size()) == 0) { - getLibkinetoOnDemandRequest(std::move(msg)); - } else { - LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; - } -} - -void tracing::IPCMonitor::setLogger(std::unique_ptr logger) -{ - logger_ = std::move(logger); -} - -void IPCMonitor::LogData(const nlohmann::json& result) -{ - auto timestamp = result["timestamp"].get(); - logger_->logUint("timestamp", timestamp); - auto duration = result["duration"].get(); - logger_->logUint("duration", duration); - auto deviceId = result["deviceId"].get(); - logger_->logStr("deviceId", std::to_string(deviceId)); - auto kind = result["kind"].get(); - logger_->logStr("kind", kind); - if (result.contains("domain") && result["domain"].is_string()) { - auto domain = result["domain"].get(); - logger_->logStr("domain", domain); - } - logger_->finalize(); -} - -void IPCMonitor::processDataMsg(std::unique_ptr msg) -{ - if (!data_ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) - msg->metadata.type, - kLibkinetoData.data(), - kLibkinetoData.size()) == 0) { - std::string message = std::string((char*)msg->buf.get(), msg->metadata.size); - try { - nlohmann::json result = nlohmann::json::parse(message); - LOG(INFO) << "Received data message : " << result; - LogData(result); - } catch (nlohmann::json::parse_error&) { - LOG(ERROR) << "Error parsing message = " << message; - return; - } - } else { - LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; - } +void IPCMonitor::processMsg(std::unique_ptr msg) { + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + // sizeof(msg->metadata.type) = 32, well above the size of the constant + // strings we are comparing against. memcmp is safe + if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoContext.data(), + kLibkinetoContext.size()) == 0) { + registerLibkinetoContext(std::move(msg)); + } else if ( + memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoRequest.data(), + kLibkinetoRequest.size()) == 0) { + getLibkinetoOnDemandRequest(std::move(msg)); + } else { + LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; + } } void IPCMonitor::getLibkinetoOnDemandRequest( - std::unique_ptr msg) -{ - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - std::string ret_config = ""; - ipcfabric::LibkinetoRequest* req = - (ipcfabric::LibkinetoRequest*)msg->buf.get(); - if (req->n == 0) { - LOG(ERROR) << "Missing pids parameter for type " << req->type; - return; - } - std::vector pids(req->pids, req->pids + req->n); - try { - ret_config = LibkinetoConfigManager::getInstance()->obtainOnDemandConfig( - std::to_string(req->jobid), pids, req->type); - VLOG(0) << "getLibkinetoOnDemandRequest() : job id " << req->jobid - << " pids = " << pids[0]; - } catch (const std::runtime_error& ex) { - LOG(ERROR) << "Kineto config manager exception : " << ex.what(); - } - std::unique_ptr ret = - ipcfabric::Message::constructMessage( - ret_config, kLibkinetoRequest); - if (!ipc_manager_->sync_send(*ret, msg->src)) { - LOG(ERROR) << "Failed to return config to libkineto: IPC sync_send fail"; - } + std::unique_ptr msg) { + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + std::string ret_config = ""; + ipcfabric::LibkinetoRequest* req = + (ipcfabric::LibkinetoRequest*)msg->buf.get(); + if (req->n == 0) { + LOG(ERROR) << "Missing pids parameter for type " << req->type; return; + } + std::vector pids(req->pids, req->pids + req->n); + try { + ret_config = LibkinetoConfigManager::getInstance()->obtainOnDemandConfig( + std::to_string(req->jobid), pids, req->type); + VLOG(0) << "getLibkinetoOnDemandRequest() : job id " << req->jobid + << " pids = " << pids[0]; + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception : " << ex.what(); + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + ret_config, kLibkinetoRequest); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to return config to libkineto: IPC sync_send fail"; + } + + return; } void IPCMonitor::registerLibkinetoContext( - std::unique_ptr msg) -{ - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - ipcfabric::LibkinetoContext* ctxt = - (ipcfabric::LibkinetoContext*)msg->buf.get(); - int32_t size = -1; - try { - size = LibkinetoConfigManager::getInstance()->registerLibkinetoContext( - std::to_string(ctxt->jobid), ctxt->pid, ctxt->gpu); - } catch (const std::runtime_error& ex) { - LOG(ERROR) << "Kineto config manager exception : " << ex.what(); - } - std::unique_ptr ret = - ipcfabric::Message::constructMessage( - size, kLibkinetoContext); - if (!ipc_manager_->sync_send(*ret, msg->src)) { - LOG(ERROR) << "Failed to send ctxt from dyno: IPC sync_send fail"; - } + std::unique_ptr msg) { + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; return; + } + ipcfabric::LibkinetoContext* ctxt = + (ipcfabric::LibkinetoContext*)msg->buf.get(); + int32_t size = -1; + try { + size = LibkinetoConfigManager::getInstance()->registerLibkinetoContext( + std::to_string(ctxt->jobid), ctxt->pid, ctxt->gpu); + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception : " << ex.what(); + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + size, kLibkinetoContext); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to send ctxt from dyno: IPC sync_send fail"; + } + + return; } } // namespace tracing diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h index cbc59fd2bbc..25d10e2ec7f 100644 --- a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h @@ -11,34 +11,27 @@ #define USE_GOOGLE_LOG #include "dynolog/src/ipcfabric/FabricManager.h" -#include "dynolog/src/Logger.h" namespace dynolog { namespace tracing { class IPCMonitor { -public: - using FabricManager = dynolog::ipcfabric::FabricManager; - IPCMonitor(const std::string& ipc_fabric_name = "dynolog"); - virtual ~IPCMonitor() {} - - void loop(); - void dataLoop(); - -public: - virtual void processMsg(std::unique_ptr msg); - virtual void processDataMsg(std::unique_ptr msg); - void getLibkinetoOnDemandRequest(std::unique_ptr msg); - void registerLibkinetoContext(std::unique_ptr msg); - void setLogger(std::unique_ptr logger); - void LogData(const nlohmann::json& result); - - std::unique_ptr ipc_manager_; - std::unique_ptr data_ipc_manager_; - std::unique_ptr logger_; + public: + using FabricManager = dynolog::ipcfabric::FabricManager; + IPCMonitor(const std::string& ipc_fabric_name = "dynolog"); + virtual ~IPCMonitor() {} + + void loop(); + + public: + virtual void processMsg(std::unique_ptr msg); + void getLibkinetoOnDemandRequest(std::unique_ptr msg); + void registerLibkinetoContext(std::unique_ptr msg); + + std::unique_ptr ipc_manager_; // friend class test_case_name##_##test_name##_Test - friend class IPCMonitorTest_LibkinetoRegisterAndOndemandTest_Test; + friend class IPCMonitorTest_LibkinetoRegisterAndOndemandTest_Test; }; } // namespace tracing -- Gitee From cbc2d919189d66fd35f4daf388daa2dded4cdc85 Mon Sep 17 00:00:00 2001 From: tangmengcheng <745274877@qq.com> Date: Thu, 19 Jun 2025 11:36:09 +0800 Subject: [PATCH 2/2] dynolog_npu modify --- msmonitor/dynolog_npu/cli/Cargo.toml | 23 +- .../dynolog_npu/cli/src/commands/dcgm.rs | 47 +- .../dynolog_npu/cli/src/commands/gputrace.rs | 17 +- msmonitor/dynolog_npu/cli/src/commands/mod.rs | 9 +- .../dynolog_npu/cli/src/commands/status.rs | 23 +- .../dynolog_npu/cli/src/commands/utils.rs | 65 +- .../dynolog_npu/cli/src/commands/version.rs | 20 +- msmonitor/dynolog_npu/cli/src/main.rs | 655 ++++++++++++- msmonitor/dynolog_npu/dynolog/src/Main.cpp | 154 +-- msmonitor/dynolog_npu/dynolog/src/Metrics.cpp | 104 +- .../dynolog/src/rpc/SimpleJsonServer.cpp | 899 ++++++++++++++---- .../dynolog/src/rpc/SimpleJsonServer.h | 103 +- .../dynolog/src/tracing/IPCMonitor.cpp | 233 +++-- .../dynolog/src/tracing/IPCMonitor.h | 35 +- 14 files changed, 1804 insertions(+), 583 deletions(-) diff --git a/msmonitor/dynolog_npu/cli/Cargo.toml b/msmonitor/dynolog_npu/cli/Cargo.toml index 479ec29ce25..5b5917add32 100644 --- a/msmonitor/dynolog_npu/cli/Cargo.toml +++ b/msmonitor/dynolog_npu/cli/Cargo.toml @@ -7,8 +7,27 @@ edition = "2021" anyhow = "1.0.57" clap = { version = "3.1.0", features = ["derive"]} serde_json = "1.0" +rustls = "0.21.0" +rustls-pemfile = "1.0" +webpki = "0.22" +x509-parser = "0.15" +der-parser = "8" +pem = "1.1" +chrono = "0.4" +num-bigint = "0.4" +openssl = { version = "0.10", features = ["vendored"] } +rpassword = "7.2.0" -# Make it work with conda -# See https://github.com/rust-lang/cargo/issues/6652 [net] git-fetch-with-cli = true + +[build] +rustflags = [ + "-C", "relocation_model=pie", + "-C", "link-args=-Wl,-z,now", + "-C", "link-args=-Wl,-z,relro", + "-C", "strip=symbols", + "-C", "overflow_checks", + "-C", "link-args=-static-libgcc", + "-C", "link-args=-static-libstdc++" +] diff --git a/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs index 89fd9ec95c6..9f342c04dfd 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/dcgm.rs @@ -3,43 +3,30 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::net::TcpStream; - use anyhow::Result; - -#[path = "utils.rs"] -mod utils; +use crate::DynoClient; +use super::utils; // This module contains the handling logic for dcgm /// Pause dcgm module profiling -pub fn run_dcgm_pause(client: TcpStream, duration_s: i32) -> Result<()> { - let request_json = format!( - r#" -{{ - "fn": "dcgmProfPause", - "duration_s": {} -}}"#, - duration_s - ); - - utils::send_msg(&client, &request_json).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - +pub fn run_dcgm_pause( + mut client: DynoClient, + duration_s: i32, +) -> Result<()> { + let msg = format!(r#"{{"fn":"dcgmPause", "duration_s":{}}}"#, duration_s); + utils::send_msg(&mut client, &msg)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); Ok(()) } /// Resume dcgm module profiling -pub fn run_dcgm_resume(client: TcpStream) -> Result<()> { - utils::send_msg(&client, r#"{"fn":"dcgmProfResume"}"#) - .expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - +pub fn run_dcgm_resume( + mut client: DynoClient, +) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"dcgmResume"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); Ok(()) -} +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs index c2e37af7e31..677ccf27d10 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/gputrace.rs @@ -3,13 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::net::TcpStream; - use anyhow::Result; use serde_json::Value; - -#[path = "utils.rs"] -mod utils; +use crate::DynoClient; +use super::utils; // This module contains the handling logic for dyno gputrace @@ -95,7 +92,7 @@ impl GpuTraceConfig { /// Gputrace command triggers GPU profiling on pytorch apps pub fn run_gputrace( - client: TcpStream, + mut client: DynoClient, job_id: u64, pids: &str, process_limit: u32, @@ -117,11 +114,11 @@ pub fn run_gputrace( kineto_config, job_id, pids, process_limit ); - utils::send_msg(&client, &request_json).expect("Error sending message to service"); + utils::send_msg(&mut client, &request_json)?; - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); + let resp_str = utils::get_resp(&mut client)?; - println!("response = {}", resp_str); + println!("response = {}\n", resp_str); let resp_v: Value = serde_json::from_str(&resp_str)?; let processes = resp_v["processesMatched"].as_array().unwrap(); @@ -213,4 +210,4 @@ PROFILE_WITH_FLOPS=false PROFILE_WITH_MODULES=true"# ); } -} +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/mod.rs b/msmonitor/dynolog_npu/cli/src/commands/mod.rs index 5acfacb640e..1f3bf17ad79 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/mod.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/mod.rs @@ -9,8 +9,11 @@ // handling code. Additionally, explicitly "exporting" all the command modules here allows // us to avoid having to explicitly list all the command modules in main.rs. -pub mod dcgm; -pub mod gputrace; pub mod status; pub mod version; -// ... add new command modules here +pub mod dcgm; +pub mod gputrace; +pub mod nputrace; +pub mod npumonitor; +pub mod utils; +// ... add new command modules here \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/status.rs b/msmonitor/dynolog_npu/cli/src/commands/status.rs index c5d9d75d909..1be17956c12 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/status.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/status.rs @@ -3,22 +3,13 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::net::TcpStream; - use anyhow::Result; +use crate::DynoClient; +use super::utils; -#[path = "utils.rs"] -mod utils; - -// This module contains the handling logic for dyno status - -/// Get system info -pub fn run_status(client: TcpStream) -> Result<()> { - utils::send_msg(&client, r#"{"fn":"getStatus"}"#).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - +pub fn run_status(mut client: DynoClient) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"getStatus"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); Ok(()) -} +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/utils.rs b/msmonitor/dynolog_npu/cli/src/commands/utils.rs index b156c1c812d..c2fdd3de618 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/utils.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/utils.rs @@ -3,33 +3,48 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::io::Read; -use std::io::Write; -use std::net::TcpStream; +use std::io::{Read, Write}; use anyhow::Result; -pub fn send_msg(mut client: &TcpStream, msg: &str) -> Result<()> { - let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); - - client.write_all(&msg_len)?; - client.write_all(msg.as_bytes()).map_err(|err| err.into()) +use crate::DynoClient; + +pub fn send_msg(client: &mut DynoClient, msg: &str) -> Result<()> { + match client { + DynoClient::Secure(secure_client) => { + let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); + secure_client.write_all(&msg_len)?; + secure_client.write_all(msg.as_bytes())?; + secure_client.flush()?; + } + DynoClient::Insecure(insecure_client) => { + let msg_len: [u8; 4] = i32::try_from(msg.len()).unwrap().to_ne_bytes(); + insecure_client.write_all(&msg_len)?; + insecure_client.write_all(msg.as_bytes())?; + insecure_client.flush()?; + } + } + Ok(()) } -pub fn get_resp(mut client: &TcpStream) -> Result { - // Response is prefixed with length - let mut resp_len: [u8; 4] = [0; 4]; - client.read_exact(&mut resp_len)?; - - let resp_len = i32::from_ne_bytes(resp_len); - let resp_len = usize::try_from(resp_len).unwrap(); - - println!("response length = {}", resp_len); - - let mut resp_str = Vec::::new(); - resp_str.resize(resp_len, 0); - - client.read_exact(resp_str.as_mut_slice())?; - - String::from_utf8(resp_str).map_err(|err| err.into()) -} +pub fn get_resp(client: &mut DynoClient) -> Result { + let mut len_buf = [0u8; 4]; + let mut resp_buf; + + match client { + DynoClient::Secure(secure_client) => { + secure_client.read_exact(&mut len_buf)?; + let len = u32::from_ne_bytes(len_buf) as usize; + resp_buf = vec![0u8; len]; + secure_client.read_exact(&mut resp_buf)?; + } + DynoClient::Insecure(insecure_client) => { + insecure_client.read_exact(&mut len_buf)?; + let len = u32::from_ne_bytes(len_buf) as usize; + resp_buf = vec![0u8; len]; + insecure_client.read_exact(&mut resp_buf)?; + } + } + + Ok(String::from_utf8(resp_buf)?) +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/commands/version.rs b/msmonitor/dynolog_npu/cli/src/commands/version.rs index 94b5129f47c..31139d56852 100644 --- a/msmonitor/dynolog_npu/cli/src/commands/version.rs +++ b/msmonitor/dynolog_npu/cli/src/commands/version.rs @@ -3,22 +3,16 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::net::TcpStream; - use anyhow::Result; - -#[path = "utils.rs"] -mod utils; +use crate::DynoClient; +use super::utils; // This module contains the handling logic for querying dyno version /// Get version info -pub fn run_version(client: TcpStream) -> Result<()> { - utils::send_msg(&client, r#"{"fn":"getVersion"}"#).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - +pub fn run_version(mut client: DynoClient) -> Result<()> { + utils::send_msg(&mut client, r#"{"fn":"getVersion"}"#)?; + let resp_str = utils::get_resp(&mut client)?; + println!("{}", resp_str); Ok(()) -} +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/cli/src/main.rs b/msmonitor/dynolog_npu/cli/src/main.rs index b7c755e61ba..2bd85a79637 100644 --- a/msmonitor/dynolog_npu/cli/src/main.rs +++ b/msmonitor/dynolog_npu/cli/src/main.rs @@ -2,18 +2,38 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - +use std::fs::File; +use std::io::BufReader; +use rustls::{Certificate, RootCertStore, PrivateKey, ClientConnection, StreamOwned}; +use std::sync::Arc; use std::net::TcpStream; use std::net::ToSocketAddrs; +use std::path::PathBuf; +use std::io; +use rpassword::prompt_password; use anyhow::Result; use clap::Parser; +use std::collections::HashSet; + +use x509_parser::prelude::*; +use x509_parser::num_bigint::ToBigInt; +use std::fs::read_to_string; +use x509_parser::public_key::RSAPublicKey; +use x509_parser::der_parser::oid; +use num_bigint::BigUint; +use openssl::pkey::PKey; +use std::io::Read; // Make all the command modules accessible to this file. mod commands; use commands::gputrace::GpuTraceConfig; use commands::gputrace::GpuTraceOptions; use commands::gputrace::GpuTraceTriggerConfig; +use commands::nputrace::NpuTraceConfig; +use commands::nputrace::NpuTraceOptions; +use commands::nputrace::NpuTraceTriggerConfig; +use commands::npumonitor::NpuMonitorConfig; use commands::*; /// Instructions on adding a new Dyno CLI command: @@ -32,6 +52,7 @@ use commands::*; /// the command dispatching logic clear and concise, please keep the code in the match branch to a minimum. const DYNO_PORT: u16 = 1778; +const MIN_RSA_KEY_LENGTH: u64 = 3072; // 最小 RSA 密钥长度(位) #[derive(Debug, Parser)] struct Opts { @@ -39,10 +60,49 @@ struct Opts { hostname: String, #[clap(long, default_value_t = DYNO_PORT)] port: u16, + #[clap(long, required = true)] + certs_dir: String, #[clap(subcommand)] cmd: Command, } +const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy"]; + +fn parse_mspti_activity_kinds(src: &str) -> Result{ + let allowed_values: HashSet<&str> = ALLOWED_VALUES.iter().cloned().collect(); + + let kinds: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); + + for kind in &kinds { + if !allowed_values.contains(kind) { + return Err(format!("Invalid MSPTI activity kind: {}, Possible values: {:?}.]", kind, allowed_values)); + } + } + + Ok(src.to_string()) +} + +const ALLOWED_HOST_SYSTEM_VALUES: &[&str] = &["cpu", "mem", "disk", "network", "osrt"]; + +fn parse_host_sys(src: &str) -> Result{ + if src == "None" { + return Ok(src.to_string()); + } + + let allowed_host_sys_values: HashSet<&str> = ALLOWED_HOST_SYSTEM_VALUES.iter().cloned().collect(); + + let host_systems: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); + + for host_system in &host_systems { + if !allowed_host_sys_values.contains(host_system) { + return Err(format!("Invalid NPU Trace host system: {}, Possible values: {:?}.]", host_system, + allowed_host_sys_values)); + } + } + let result = host_systems.join(","); + Ok(result) +} + #[derive(Debug, Parser)] enum Command { /// Check the status of a dynolog process @@ -51,7 +111,7 @@ enum Command { Version, /// Capture gputrace Gputrace { - /// Job id of the application to trace + /// Job id of the application to trace. #[clap(long, default_value_t = 0)] job_id: u64, /// List of pids to capture trace for (comma separated). @@ -66,32 +126,134 @@ enum Command { /// Log file for trace. #[clap(long)] log_file: String, - /// Unix timestamp used for synchronized collection (milliseconds since epoch) + /// Unix timestamp used for synchronized collection (milliseconds since epoch). #[clap(long, default_value_t = 0)] profile_start_time: u64, /// Start iteration roundup, starts an iteration based trace at a multiple /// of this value. #[clap(long, default_value_t = 1)] profile_start_iteration_roundup: u64, - /// Max number of processes to profile + /// Max number of processes to profile. #[clap(long, default_value_t = 3)] process_limit: u32, - /// Record PyTorch operator input shapes and types + /// Record PyTorch operator input shapes and types. #[clap(long, action)] record_shapes: bool, - /// Profile PyTorch memory + /// Profile PyTorch memory. #[clap(long, action)] profile_memory: bool, - /// Capture Python stacks in traces + /// Capture Python stacks in traces. #[clap(long, action)] with_stacks: bool, - /// Annotate operators with analytical flops + /// Annotate operators with analytical flops. #[clap(long, action)] with_flops: bool, - /// Capture PyTorch operator modules in traces + /// Capture PyTorch operator modules in traces. #[clap(long, action)] with_modules: bool, }, + /// Capture nputrace. Subcommand functions aligned with Ascend Torch Profiler. + Nputrace { + /// Job id of the application to trace. + #[clap(long, default_value_t = 0)] + job_id: u64, + /// List of pids to capture trace for (comma separated). + #[clap(long, default_value = "0")] + pids: String, + /// Duration of trace to collect in ms. + #[clap(long, default_value_t = 500)] + duration_ms: u64, + /// Training iterations to collect, this takes precedence over duration. + #[clap(long, default_value_t = -1)] + iterations: i64, + /// Log file for trace. + #[clap(long)] + log_file: String, + /// Unix timestamp used for synchronized collection (milliseconds since epoch). + #[clap(long, default_value_t = 0)] + profile_start_time: u64, + /// Number of steps to start profile. + #[clap(long, default_value_t = 0)] + start_step: u64, + /// Max number of processes to profile. + #[clap(long, default_value_t = 3)] + process_limit: u32, + /// Whether to record PyTorch operator input shapes and types. + #[clap(long, action)] + record_shapes: bool, + /// Whether to profile PyTorch memory. + #[clap(long, action)] + profile_memory: bool, + /// Whether to profile the Python call stack in trace. + #[clap(long, action)] + with_stack: bool, + /// Annotate operators with analytical flops. + #[clap(long, action)] + with_flops: bool, + /// Whether to profile PyTorch operator modules in traces. + #[clap(long, action)] + with_modules: bool, + /// The scope of the profile's events. + #[clap(long, value_parser = ["CPU,NPU", "NPU,CPU", "CPU", "NPU"], default_value = "CPU,NPU")] + activities: String, + /// Profiler level. + #[clap(long, value_parser = ["Level0", "Level1", "Level2", "Level_none"], default_value = "Level0")] + profiler_level: String, + /// AIC metrics. + #[clap(long, value_parser = ["AiCoreNone", "PipeUtilization", "ArithmeticUtilization", "Memory", "MemoryL0", "ResourceConflictRatio", "MemoryUB", "L2Cache", "MemoryAccess"], default_value = "AiCoreNone")] + aic_metrics: String, + /// Whether to analyse the data after collection. + #[clap(long, action)] + analyse: bool, + /// Whether to collect L2 cache. + #[clap(long, action)] + l2_cache: bool, + /// Whether to collect op attributes. + #[clap(long, action)] + op_attr: bool, + /// Whether to enable MSTX. + #[clap(long, action)] + msprof_tx: bool, + /// GC detect threshold. + #[clap(long)] + gc_detect_threshold: Option, + /// Whether to streamline data after analyse is complete. + #[clap(long, value_parser = ["true", "false"], default_value = "true")] + data_simplification: String, + /// Types of data exported by the profiler. + #[clap(long, value_parser = ["Text", "Db"], default_value = "Text")] + export_type: String, + /// Obtain the system data on the host side. + #[clap(long, value_parser = parse_host_sys, default_value = "None")] + host_sys: String, + /// Whether to enable sys io. + #[clap(long, action)] + sys_io: bool, + /// Whether to enable sys interconnection. + #[clap(long, action)] + sys_interconnection: bool, + /// The domain that needs to be enabled in mstx mode. + #[clap(long)] + mstx_domain_include: Option, + /// Domains that do not need to be enabled in mstx mode. + #[clap(long)] + mstx_domain_exclude: Option, + }, + /// Ascend MSPTI Monitor + NpuMonitor { + /// Start NPU monitor. + #[clap(long, action)] + npu_monitor_start: bool, + /// Stop NPU monitor. + #[clap(long, action)] + npu_monitor_stop: bool, + /// NPU monitor report interval in seconds. + #[clap(long, default_value_t = 60)] + report_interval_s: u32, + /// MSPTI collect activity kind + #[clap(long, value_parser = parse_mspti_activity_kinds, default_value = "Marker")] + mspti_activity_kind: String, + }, /// Pause dcgm profiling. This enables running tools like Nsight compute and avoids conflicts. DcgmPause { /// Duration to pause dcgm profiling in seconds @@ -102,29 +264,397 @@ enum Command { DcgmResume, } -/// Create a socket connection to dynolog -fn create_dyno_client(host: &str, port: u16) -> Result { +struct ClientConfigPath { + cert_path: PathBuf, + key_path: PathBuf, + ca_cert_path: PathBuf, +} + +fn verify_certificate(cert_der: &[u8], is_root_cert: bool) -> Result<()> { + // 解析 X509 证书 + let (_, cert) = X509Certificate::from_der(cert_der) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; + + // 检查证书版本是否为 X.509v3 + if cert.tbs_certificate.version != X509Version(2) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate is not X.509v3" + ).into()); + } + + // 检查证书签名算法 + let sig_alg = cert.signature_algorithm.algorithm; + + // 定义不安全的算法 OID + let md2_rsa = oid!(1.2.840.113549.1.1.2); // MD2 with RSA + let md5_rsa = oid!(1.2.840.113549.1.1.4); // MD5 with RSA + let sha1_rsa = oid!(1.2.840.113549.1.1.5); // SHA1 with RSA + + // 检查是否使用不安全的算法 + if sig_alg == md2_rsa || sig_alg == md5_rsa || sig_alg == sha1_rsa { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate uses insecure signature algorithm" + ).into()); + } + + // 定义 RSA 签名算法 OID + let rsa_sha256 = oid!(1.2.840.113549.1.1.11); // RSA with SHA256 + let rsa_sha384 = oid!(1.2.840.113549.1.1.12); // RSA with SHA384 + let rsa_sha512 = oid!(1.2.840.113549.1.1.13); // RSA with SHA512 + + // 检查 RSA 密钥长度 + if sig_alg == rsa_sha256 || sig_alg == rsa_sha384 || sig_alg == rsa_sha512 { + // 获取公钥 + if let Ok((_, public_key)) = SubjectPublicKeyInfo::from_der(&cert.tbs_certificate.subject_pki.subject_public_key.data) { + if let Ok((_, rsa_key)) = RSAPublicKey::from_der(&public_key.subject_public_key.data) { + // 检查 RSA 密钥长度 + let modulus = BigUint::from_bytes_be(&rsa_key.modulus); + let key_length = modulus.bits(); + if key_length < MIN_RSA_KEY_LENGTH { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("RSA key length {} bits is less than required {} bits", key_length, MIN_RSA_KEY_LENGTH) + ).into()); + } + } + } + } + + // 检查证书的扩展域 + let mut has_ca_constraint = false; + let mut has_key_usage = false; + let mut has_crl_sign = false; + let mut has_cert_sign = false; + + for ext in cert.tbs_certificate.extensions() { + if ext.oid == oid_registry::OID_X509_EXT_BASIC_CONSTRAINTS { + if let Ok((_, constraints)) = BasicConstraints::from_der(ext.value) { + has_ca_constraint = constraints.ca; + } else { + println!("Failed to parse Basic Constraints"); + } + } else if ext.oid == oid_registry::OID_X509_EXT_KEY_USAGE { + println!("Found Key Usage extension"); + if let Ok((_, usage)) = KeyUsage::from_der(ext.value) { + has_key_usage = true; + has_cert_sign = usage.key_cert_sign(); + has_crl_sign = usage.crl_sign(); + } else { + println!("Failed to parse Key Usage"); + } + } + } + + // 根据证书类型进行不同的验证 + if is_root_cert { + // 根证书验证要求 + if !has_ca_constraint { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have CA constraint" + ).into()); + } + if !has_key_usage { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have key usage extension" + ).into()); + } + if !has_cert_sign { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have certificate signature permission" + ).into()); + } + if !has_crl_sign { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Root certificate must have CRL signature permission" + ).into()); + } + } else { + // 客户端证书验证要求 + if has_ca_constraint { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Client certificate should not have CA constraint" + ).into()); + } + if !has_key_usage { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Client certificate must have key usage extension" + ).into()); + } + } + + // 检查证书有效期 + let now = chrono::Utc::now(); + let not_before = chrono::DateTime::from_timestamp( + cert.tbs_certificate.validity.not_before.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_before date"))?; + + let not_after = chrono::DateTime::from_timestamp( + cert.tbs_certificate.validity.not_after.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_after date"))?; + + if now < not_before { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Certificate is not yet valid. Valid from: {}", not_before) + ).into()); + } + + if now > not_after { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Certificate has expired. Expired at: {}", not_after) + ).into()); + } + + Ok(()) +} + +fn is_cert_revoked(cert_der: &[u8], crl_path: &PathBuf) -> Result { + // 解析 X509 证书 + let (_, cert) = X509Certificate::from_der(cert_der) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?; + + // 读取 CRL 文件 + let crl_data = read_to_string(crl_path)?; + let (_, pem) = pem::parse_x509_pem(crl_data.as_bytes()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL PEM: {:?}", e)))?; + + // 解析 CRL + let (_, crl) = CertificateRevocationList::from_der(&pem.contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse CRL: {:?}", e)))?; + + // 检查 CRL 的有效期 + let now = chrono::Utc::now(); + let crl_not_before = chrono::DateTime::from_timestamp( + crl.tbs_cert_list.this_update.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL this_update date"))?; + + let crl_not_after = if let Some(next_update) = crl.tbs_cert_list.next_update { + chrono::DateTime::from_timestamp( + next_update.timestamp(), + 0 + ).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRL next_update date"))? + } else { + crl_not_before + chrono::Duration::days(365) + }; + + // 检查 CRL 是否在有效期内 + if now < crl_not_before { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL is not yet valid. Valid from: {}", crl_not_before) + ).into()); + } + + if now > crl_not_after { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL has expired. Expired at: {}", crl_not_after) + ).into()); + } + + // 获取证书序列号 + let cert_serial = cert.tbs_certificate.serial.to_bigint() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert certificate serial to BigInt"))?; + + // 检查 CRL 吊销条目 + for revoked in crl.iter_revoked_certificates() { + let revoked_serial = revoked.user_certificate.to_bigint() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to convert revoked certificate serial to BigInt"))?; + + if revoked_serial == cert_serial { + return Ok(true); + } + } + Ok(false) +} + +enum DynoClient { + Secure(StreamOwned), + Insecure(TcpStream), +} + +fn create_dyno_client( + host: &str, + port: u16, + certs_dir: &str, +) -> Result { + if certs_dir == "NO_CERTS" { + println!("Running in no-certificate mode"); + create_dyno_client_with_no_certs(host, port) + } else { + println!("Running in certificate mode"); + let certs_dir = PathBuf::from(certs_dir); + let config = ClientConfigPath { + cert_path: certs_dir.join("client.crt"), + key_path: certs_dir.join("client.key"), + ca_cert_path: certs_dir.join("ca.crt"), + }; + let client = create_dyno_client_with_certs(host, port, &config)?; + Ok(DynoClient::Secure(client)) + } +} + +fn create_dyno_client_with_no_certs( + host: &str, + port: u16, +) -> Result { let addr = (host, port) .to_socket_addrs()? .next() .expect("Failed to connect to the server"); + let stream = TcpStream::connect(addr)?; + Ok(DynoClient::Insecure(stream)) +} + +fn create_dyno_client_with_certs( + host: &str, + port: u16, + config: &ClientConfigPath, +) -> Result> { + let addr = (host, port) + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new( + io::ErrorKind::NotFound, + "Could not resolve the host address" + ))?; + + let stream = TcpStream::connect(addr)?; + + println!("Loading CA cert from: {}", config.ca_cert_path.display()); + let mut root_store = RootCertStore::empty(); + let ca_file = File::open(&config.ca_cert_path)?; + let mut ca_reader = BufReader::new(ca_file); + let ca_certs = rustls_pemfile::certs(&mut ca_reader)?; + for ca_cert in &ca_certs { + verify_certificate(ca_cert, true)?; // 验证根证书 + } + for ca_cert in ca_certs { + root_store.add(&Certificate(ca_cert))?; + } + + println!("Loading client cert from: {}", config.cert_path.display()); + let cert_file = File::open(&config.cert_path)?; + let mut cert_reader = BufReader::new(cert_file); + let certs = rustls_pemfile::certs(&mut cert_reader)?; + + // 检查客户端证书的基本要求 + for cert in &certs { + verify_certificate(cert, false)?; // 验证客户端证书 + } + + // 检查证书吊销状态 + let crl_path = config.cert_path.parent().unwrap().join("ca.crl"); + if crl_path.exists() { + println!("Checking CRL file: {}", crl_path.display()); + for cert in &certs { + match is_cert_revoked(cert, &crl_path) { + Ok(true) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Certificate is revoked" + ).into()); + } + Ok(false) => { + continue; + } + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("CRL verification failed: {}", e) + ).into()); + } + } + } + } else { + println!("CRL file does not exist: {}", crl_path.display()); + } - TcpStream::connect(addr).map_err(|err| err.into()) + let certs = certs.into_iter().map(Certificate).collect(); + + println!("Loading client key from: {}", config.key_path.display()); + let key_file = File::open(&config.key_path)?; + let mut key_reader = BufReader::new(key_file); + + // 检查私钥是否加密 + let mut key_data = Vec::new(); + key_reader.read_to_end(&mut key_data)?; + let key_str = String::from_utf8_lossy(&key_data); + let is_encrypted = key_str.contains("ENCRYPTED"); + + // 根据是否加密来加载私钥 + let keys = if is_encrypted { + // 如果私钥是加密的,请求用户输入密码 + let mut password = prompt_password("Please enter the certificate password: ")?; + let pkey = PKey::private_key_from_pem_passphrase(&key_data, password.as_bytes()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to decrypt private key: {}", e)))?; + + // 清除密码 + password.clear(); + + // 返回私钥 + vec![pkey.private_key_to_der() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to convert private key to DER: {}", e)))?] + } else { + // 如果私钥未加密,直接加载 + let mut key_reader = BufReader::new(File::open(&config.key_path)?); + rustls_pemfile::pkcs8_private_keys(&mut key_reader)? + }; + + if keys.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "No private key found in the key file" + ).into()); + } + let key = PrivateKey(keys[0].clone()); + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_client_auth_cert(certs, key)?; + + let server_name = rustls::ServerName::try_from(host) + .map_err(|e| io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid hostname: {}", e) + ))?; + + let conn = rustls::ClientConnection::new( + Arc::new(config), + server_name + )?; + + Ok(StreamOwned::new(conn, stream)) } + fn main() -> Result<()> { let Opts { hostname, port, + certs_dir, cmd, } = Opts::parse(); - let dyno_client = - create_dyno_client(&hostname, port).expect("Couldn't connect to the server..."); + let client = create_dyno_client(&hostname, port, &certs_dir) + .expect("Couldn't connect to the server..."); match cmd { - Command::Status => status::run_status(dyno_client), - Command::Version => version::run_version(dyno_client), + Command::Status => status::run_status(client), + Command::Version => version::run_version(client), Command::Gputrace { job_id, pids, @@ -163,10 +693,95 @@ fn main() -> Result<()> { trigger_config, trace_options, }; - gputrace::run_gputrace(dyno_client, job_id, &pids, process_limit, trace_config) + gputrace::run_gputrace(client, job_id, &pids, process_limit, trace_config) } - Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(dyno_client, duration_s), - Command::DcgmResume => dcgm::run_dcgm_resume(dyno_client), + Command::Nputrace { + job_id, + pids, + log_file, + duration_ms, + iterations, + profile_start_time, + start_step, + process_limit, + record_shapes, + profile_memory, + with_stack, + with_flops, + with_modules, + activities, + analyse, + profiler_level, + aic_metrics, + l2_cache, + op_attr, + msprof_tx, + gc_detect_threshold, + data_simplification, + export_type, + host_sys, + sys_io, + sys_interconnection, + mstx_domain_include, + mstx_domain_exclude, + } => { + let trigger_config = if iterations > 0 { + NpuTraceTriggerConfig::IterationBased { + start_step, + iterations, + } + } else { + NpuTraceTriggerConfig::DurationBased { + profile_start_time, + duration_ms, + } + }; + + let trace_options = NpuTraceOptions { + record_shapes, + profile_memory, + with_stack, + with_flops, + with_modules, + activities, + analyse, + profiler_level, + aic_metrics, + l2_cache, + op_attr, + msprof_tx, + gc_detect_threshold, + data_simplification, + export_type, + host_sys, + sys_io, + sys_interconnection, + mstx_domain_include, + mstx_domain_exclude, + }; + let trace_config = NpuTraceConfig { + log_file, + trigger_config, + trace_options, + }; + nputrace::run_nputrace(client, job_id, &pids, process_limit, trace_config) + } + Command::NpuMonitor { + npu_monitor_start, + npu_monitor_stop, + report_interval_s, + mspti_activity_kind, + } => { + let npu_mon_config = NpuMonitorConfig { + npu_monitor_start, + npu_monitor_stop, + report_interval_s, + mspti_activity_kind + }; + npumonitor::run_npumonitor(client, npu_mon_config) + } + Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(client, duration_s), + Command::DcgmResume => dcgm::run_dcgm_resume(client), // ... add new commands here } -} +} \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/Main.cpp b/msmonitor/dynolog_npu/dynolog/src/Main.cpp index 41f224fc905..a24bb802ab4 100644 --- a/msmonitor/dynolog_npu/dynolog/src/Main.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/Main.cpp @@ -15,6 +15,7 @@ #include "dynolog/src/KernelCollector.h" #include "dynolog/src/Logger.h" #include "dynolog/src/ODSJsonLogger.h" + #include "dynolog/src/PerfMonitor.h" #include "dynolog/src/ScubaLogger.h" #include "dynolog/src/ServiceHandler.h" @@ -28,6 +29,10 @@ #include "dynolog/src/PrometheusLogger.h" #endif +#ifdef USE_TENSORBOARD +#include "dynolog/src/DynologTensorBoardLogger.h" +#endif + using namespace dynolog; using json = nlohmann::json; namespace hbt = facebook::hbt; @@ -62,39 +67,47 @@ DEFINE_bool( "Enabled GPU monitorng, currently supports NVIDIA GPUs."); DEFINE_bool(enable_perf_monitor, false, "Enable heartbeat perf monitoring."); -std::unique_ptr getLogger(const std::string& scribe_category = "") { - std::vector> loggers; +std::unique_ptr getLogger(const std::string& scribe_category = "") +{ + std::vector> loggers; #ifdef USE_PROMETHEUS - if (FLAGS_use_prometheus) { + if (FLAGS_use_prometheus) { loggers.push_back(std::make_unique()); - } + } #endif - if (FLAGS_use_fbrelay) { +#ifdef USE_TENSORBOARD + if (!FLAGS_metric_log_dir.empty()) { + loggers.push_back(std::make_unique(FLAGS_metric_log_dir)); + } +#endif + if (FLAGS_use_fbrelay) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_ODS) { + } + if (FLAGS_use_ODS) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_JSON) { + } + if (FLAGS_use_JSON) { loggers.push_back(std::make_unique()); - } - if (FLAGS_use_scuba && !scribe_category.empty()) { + } + if (FLAGS_use_scuba && !scribe_category.empty()) { loggers.push_back(std::make_unique(scribe_category)); - } - return std::make_unique(std::move(loggers)); + } + return std::make_unique(std::move(loggers)); } -auto next_wakeup(int sec) { - return std::chrono::steady_clock::now() + std::chrono::seconds(sec); +auto next_wakeup(int sec) +{ + return std::chrono::steady_clock::now() + std::chrono::seconds(sec); } -void kernel_monitor_loop() { - KernelCollector kc; +void kernel_monitor_loop() +{ + KernelCollector kc; - LOG(INFO) << "Running kernel monitor loop : interval = " + LOG(INFO) << "Running kernel monitor loop : interval = " << FLAGS_kernel_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_kernel_monitor_reporting_interval_s); @@ -105,20 +118,21 @@ void kernel_monitor_loop() { /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -void perf_monitor_loop() { - PerfMonitor pm( +void perf_monitor_loop() +{ + PerfMonitor pm( hbt::CpuSet::makeAllOnline(), std::vector{"instructions", "cycles"}, getDefaultPmuDeviceManager(), getDefaultMetrics()); - LOG(INFO) << "Running perf monitor loop : interval = " - << FLAGS_perf_monitor_reporting_interval_s << " s."; + LOG(INFO) << "Running perf monitor loop : interval = " + << FLAGS_perf_monitor_reporting_interval_s << " s."; - while (1) { + while (1) { auto logger = getLogger(); auto wakeup_timepoint = next_wakeup(FLAGS_perf_monitor_reporting_interval_s); @@ -129,22 +143,24 @@ void perf_monitor_loop() { logger->finalize(); /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -auto setup_server(std::shared_ptr handler) { - return std::make_unique>( - handler, FLAGS_port); +auto setup_server(std::shared_ptr handler) +{ + return std::make_unique>( + handler, FLAGS_port); } -void gpu_monitor_loop(std::shared_ptr dcgm) { - auto logger = getLogger(FLAGS_scribe_category); +void gpu_monitor_loop(std::shared_ptr dcgm) +{ + auto logger = getLogger(FLAGS_scribe_category); - LOG(INFO) << "Running DCGM loop : interval = " - << FLAGS_dcgm_reporting_interval_s << " s."; - LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; + LOG(INFO) << "Running DCGM loop : interval = " + << FLAGS_dcgm_reporting_interval_s << " s."; + LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; - while (1) { + while (1) { auto wakeup_timepoint = next_wakeup(FLAGS_dcgm_reporting_interval_s); dcgm->update(); @@ -152,55 +168,65 @@ void gpu_monitor_loop(std::shared_ptr dcgm) { /* sleep override */ std::this_thread::sleep_until(wakeup_timepoint); - } + } } -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - FLAGS_logtostderr = 1; - google::InitGoogleLogging(argv[0]); +int main(int argc, char** argv) +{ + gflags::ParseCommandLineFlags(&argc, &argv, true); + FLAGS_logtostderr = 1; + google::InitGoogleLogging(argv[0]); - LOG(INFO) << "Starting dynolog, version = " DYNOLOG_VERSION - << ", build git-hash = " DYNOLOG_GIT_REV; + LOG(INFO) << "Starting Ascend Extension for dynolog, version = " DYNOLOG_VERSION + << ", build git-hash = " DYNOLOG_GIT_REV; - std::shared_ptr dcgm; + std::shared_ptr dcgm; - std::unique_ptr ipcmon; - std::unique_ptr ipcmon_thread, gpumon_thread, pm_thread; + std::unique_ptr ipcmon; + std::unique_ptr ipcmon_thread; + std::unique_ptr data_ipcmon_thread; + std::unique_ptr gpumon_thread; + std::unique_ptr pm_thread; - if (FLAGS_enable_ipc_monitor) { + if (FLAGS_enable_ipc_monitor) { LOG(INFO) << "Starting IPC Monitor"; ipcmon = std::make_unique(); + ipcmon->setLogger(std::move(getLogger())); ipcmon_thread = std::make_unique([&ipcmon]() { ipcmon->loop(); }); - } + data_ipcmon_thread = + std::make_unique([&ipcmon]() { ipcmon->dataLoop(); }); + } - if (FLAGS_enable_gpu_monitor) { + if (FLAGS_enable_gpu_monitor) { dcgm = gpumon::DcgmGroupInfo::factory( gpumon::FLAGS_dcgm_fields, FLAGS_dcgm_reporting_interval_s * 1000); gpumon_thread = std::make_unique(gpu_monitor_loop, dcgm); - } - std::thread km_thread{kernel_monitor_loop}; - if (FLAGS_enable_perf_monitor) { + } + std::thread km_thread{kernel_monitor_loop}; + if (FLAGS_enable_perf_monitor) { pm_thread = std::make_unique(perf_monitor_loop); - } + } + + // setup service + auto handler = std::make_shared(dcgm); - // setup service - auto handler = std::make_shared(dcgm); + // use simple json RPC server for now + auto server = setup_server(handler); + server->run(); - // use simple json RPC server for now - auto server = setup_server(handler); - server->run(); + if (km_thread.joinable()) { + km_thread.join(); + } - km_thread.join(); - if (pm_thread) { + if (pm_thread && pm_thread->joinable()) { pm_thread->join(); - } - if (gpumon_thread) { + } + if (gpumon_thread && gpumon_thread->joinable()) { gpumon_thread->join(); - } + } - server->stop(); + server->stop(); - return 0; + return 0; } diff --git a/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp index dd3327b44b5..959bb1778ce 100644 --- a/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/Metrics.cpp @@ -10,90 +10,30 @@ namespace dynolog { -const std::vector getAllMetrics() { - static std::vector metrics_ = { - {.name = "cpu_util", - .type = MetricType::Ratio, - .desc = "Fraction of total CPU time spend on user or system mode."}, - {.name = "cpu_u", - .type = MetricType::Ratio, - .desc = "Fraction of total CPU time spent in user mode."}, - {.name = "cpu_s", - .type = MetricType::Ratio, - .desc = "Fraction of total CPU time spent in system mode."}, - {.name = "cpu_i", - .type = MetricType::Ratio, - .desc = "Fraction of total CPU time spent in idle mode."}, - {.name = "mips", - .type = MetricType::Ratio, - .desc = "Number of million instructions executed per second."}, - {.name = "mega_cycles_per_second", - .type = MetricType::Ratio, - .desc = "Number of active CPU clock cycles per second."}, - {.name = "uptime", - .type = MetricType::Instant, - .desc = "How long the system has been running in seconds."}, - }; - static std::map cpustats = { - {"cpu_u_ms", "user"}, - {"cpu_s_ms", "system"}, - {"cpu_n_ms", "nice"}, - {"cpu_w_ms", "iowait"}, - {"cpu_x_ms", "irq"}, - {"cpu_y_ms", "softirq"}, - }; - - auto metrics = metrics_; - - for (const auto& [name, cpu_mode] : cpustats) { - MetricDesc m{ - .name = name, - .type = MetricType::Delta, - .desc = fmt::format( - "Total CPU time in milliseconds spent in {} mode. " - "For more details please see man page for /proc/stat", - cpu_mode)}; - metrics.push_back(m); - } - - return metrics; +const std::vector getAllMetrics() +{ + static std::vector metrics_ = { + {.name = "kindName", + .type = MetricType::Instant, + .desc = "Report data kind name"}, + {.name = "duration", + .type = MetricType::Delta, + .desc = "Total execution time for corresponding kind"}, + {.name = "timestamp", + .type = MetricType::Instant, + .desc = "The timestamp of the reported data"}, + {.name = "deviceId", + .type = MetricType::Instant, + .desc = "The ID of the device for reporting data"}, + }; + return metrics_; } // These metrics are dynamic per network drive -const std::vector getNetworkMetrics() { - static std::vector metrics_ = { - {.name = "tx_bytes", - .type = MetricType::Delta, - .desc = - "Total bytes transmitted/received over the specific network device."}, - {.name = "rx_bytes", - .type = MetricType::Delta, - .desc = - "Total bytes transmitted/received over the specific network device."}, - {.name = "tx_packets", - .type = MetricType::Delta, - .desc = - "Total packets transmitted/received over the specific network device."}, - {.name = "rx_packets", - .type = MetricType::Delta, - .desc = - "Total packets transmitted/received over the specific network device."}, - {.name = "tx_errors", - .type = MetricType::Delta, - .desc = "Total transmit/receive errors on the specific network device."}, - {.name = "rx_errors", - .type = MetricType::Delta, - .desc = "Total transmit/receive errors on the specific network device."}, - {.name = "tx_drops", - .type = MetricType::Delta, - .desc = - "Total transmit/receive packet drops on the specific network device."}, - {.name = "rx_drops", - .type = MetricType::Delta, - .desc = - "Total transmit/receive packet drops on the specific network device."}, - }; - return metrics_; +const std::vector getNetworkMetrics() +{ + static std::vector metrics_ = {}; + return metrics_; } -} // namespace dynolog +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp index 579539f754e..dfe980a4013 100644 --- a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.cpp @@ -10,76 +10,107 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DEFINE_string(certs_dir, "", "TLS crets dir"); constexpr int CLIENT_QUEUE_LEN = 50; +const std::string NO_CERTS_MODE = "NO_CERTS"; +const size_t MIN_RSA_KEY_LENGTH = 3072; namespace dynolog { -SimpleJsonServerBase::SimpleJsonServerBase(int port) : port_(port) { - initSocket(); -} - -SimpleJsonServerBase::~SimpleJsonServerBase() { - if (thread_) { - stop(); - } - close(sock_fd_); -} - -void SimpleJsonServerBase::initSocket() { - struct sockaddr_in6 server_addr; - - /* Create socket for listening (client requests).*/ - sock_fd_ = ::socket(AF_INET6, SOCK_STREAM, 0); - if (sock_fd_ == -1) { - std::perror("socket()"); - return; - } - - /* Set socket to reuse address in case server is restarted.*/ - int flag = 1; - int ret = - ::setsockopt(sock_fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)); - if (ret == -1) { - std::perror("setsockopt()"); - return; - } - - // in6addr_any allows us to bind to both IPv4 and IPv6 clients. - server_addr.sin6_addr = in6addr_any; - server_addr.sin6_family = AF_INET6; - server_addr.sin6_port = htons(port_); - - /* Bind address and socket together */ - ret = ::bind(sock_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); - if (ret == -1) { - std::perror("bind()"); - close(sock_fd_); - return; - } +void secure_clear_password(std::string& password); + +SimpleJsonServerBase::SimpleJsonServerBase(int port) : port_(port) +{ + try { + initSocket(); + if (FLAGS_certs_dir != NO_CERTS_MODE) { + init_openssl(); + ctx_ = create_context(); + configure_context(ctx_); + } + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to initialize server: " << e.what(); + throw; + } +} - /* Create listening queue (client requests) */ - ret = ::listen(sock_fd_, CLIENT_QUEUE_LEN); - if (ret == -1) { - std::perror("listen()"); +SimpleJsonServerBase::~SimpleJsonServerBase() +{ + if (thread_) { + stop(); + } close(sock_fd_); - return; - } - - /* Get port if assigned 0 */ - if (port_ == 0) { - socklen_t len_out = sizeof(server_addr); - ret = ::getsockname(sock_fd_, (struct sockaddr*)&server_addr, &len_out); - if (ret < 0 || len_out != sizeof(server_addr)) { - std::perror("getsockname()"); - } else { - port_ = ntohs(server_addr.sin6_port); - LOG(INFO) << "System assigned port = " << ntohs(server_addr.sin6_port); + if (FLAGS_certs_dir != NO_CERTS_MODE && ctx_) { + SSL_CTX_free(ctx_); + } +} + +void SimpleJsonServerBase::initSocket() +{ + struct sockaddr_in6 server_addr; + + /* Create socket for listening (client requests). */ + sock_fd_ = ::socket(AF_INET6, SOCK_STREAM, 0); + if (sock_fd_ == -1) { + std::perror("socket()"); + return; + } + + /* Set socket to reuse address in case server is restarted. */ + int flag = 1; + int ret = + ::setsockopt(sock_fd_, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)); + if (ret == -1) { + std::perror("setsockopt()"); + return; + } + + // in6addr_any allows us to bind to both IPv4 and IPv6 clients. + server_addr.sin6_addr = in6addr_any; + server_addr.sin6_family = AF_INET6; + server_addr.sin6_port = htons(port_); + + /* Bind address and socket together */ + ret = ::bind(sock_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (ret == -1) { + std::perror("bind()"); + close(sock_fd_); + return; } - } - LOG(INFO) << "Listening to connections on port " << port_; - initSuccess_ = true; + /* Create listening queue (client requests) */ + ret = ::listen(sock_fd_, CLIENT_QUEUE_LEN); + if (ret == -1) { + std::perror("listen()"); + close(sock_fd_); + return; + } + + /* Get port if assigned 0 */ + if (port_ == 0) { + socklen_t len_out = sizeof(server_addr); + ret = ::getsockname(sock_fd_, (struct sockaddr*)&server_addr, &len_out); + if (ret < 0 || len_out != sizeof(server_addr)) { + std::perror("getsockname()"); + } else { + port_ = ntohs(server_addr.sin6_port); + LOG(INFO) << "System assigned port = " << ntohs(server_addr.sin6_port); + } + } + + LOG(INFO) << "Listening to connections on port " << port_; + initSuccess_ = true; } /* A simple wrapper to accept connections and read data @@ -90,141 +121,655 @@ void SimpleJsonServerBase::initSocket() { * : char json[] */ class ClientSocketWrapper { - public: - ~ClientSocketWrapper() { - if (::close(client_sock_fd_) < 0) { - std::perror("close()"); - } - } - - bool accept(int server_socket) { - struct sockaddr_in6 client_addr; - socklen_t client_addr_len = sizeof(client_addr); - std::array client_addr_str; - - client_sock_fd_ = ::accept( - server_socket, (struct sockaddr*)&client_addr, &client_addr_len); - if (client_sock_fd_ == -1) { - std::perror("accept()"); - return false; - } - - inet_ntop( - AF_INET6, - &(client_addr.sin6_addr), - client_addr_str.data(), - client_addr_str.size()); - LOG(INFO) << "Received connection from " << client_addr_str.data(); - return true; - } - - /* Reads an entire message from the client, - * expected to return json*/ - std::string get_message() { - /* Wait for data from client */ - int32_t msg_size = -1; - // TODO use decltype or type alias - if (!read_helper((uint8_t*)&msg_size, sizeof(msg_size)) || msg_size <= 0) { - LOG(ERROR) << "Invalid message size = " << msg_size; - return ""; - } - - std::string message; - message.resize(msg_size); - int recv = 0; - - /* Wait for data from client */ - int ret = 1; - while (recv < msg_size && ret > 0) { - ret = read_helper((uint8_t*)&message[recv], msg_size - recv); - recv += ret > 0 ? ret : 0; - } - - if (recv != msg_size) { - LOG(ERROR) << "Received parital message, expected size " << msg_size - << "found : " << recv; - LOG(ERROR) << "Message received = " << message; - return ""; - } - - return message; - } - - bool send_response(const std::string& response) { - // send size prefixed response - int32_t size = response.size(); - int ret = ::write(client_sock_fd_, (void*)&size, sizeof(size)); - if (ret == -1) { - std::perror("read()"); +public: + ~ClientSocketWrapper() + { + if (FLAGS_certs_dir != NO_CERTS_MODE && ssl_) { + SSL_shutdown(ssl_); + SSL_free(ssl_); + } + if (client_sock_fd_ != -1) { + ::close(client_sock_fd_); + } } - int sent = 0; - while (sent < size && ret > 0) { - ret = ::write(client_sock_fd_, (void*)&response[sent], size - sent); - if (ret == -1) { - std::perror("read()"); - } else { - sent += ret; - } + bool accept(int server_socket, SSL_CTX* ctx) + { + struct sockaddr_in6 client_addr; + socklen_t client_addr_len = sizeof(client_addr); + std::array client_addr_str; + + client_sock_fd_ = ::accept( + server_socket, (struct sockaddr*)&client_addr, &client_addr_len); + if (client_sock_fd_ == -1) { + std::perror("accept()"); + return false; + } + + inet_ntop( + AF_INET6, + &(client_addr.sin6_addr), + client_addr_str.data(), + client_addr_str.size()); + LOG(INFO) << "Received connection from " << client_addr_str.data(); + + if (FLAGS_certs_dir == NO_CERTS_MODE) { + LOG(INFO) << "No certs mode"; + return true; + } + + ssl_ = SSL_new(ctx); + SSL_set_fd(ssl_, client_sock_fd_); + if (SSL_accept(ssl_) <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + LOG(INFO) << "SSL handshake success"; + return true; } - if (sent < response.size()) { - LOG(ERROR) << "Unable to write full response"; - return false; + std::string get_message() + { + int32_t msg_size = -1; + if (!read_helper((uint8_t*)&msg_size, sizeof(msg_size)) || msg_size <= 0) { + LOG(ERROR) << "Invalid message size = " << msg_size; + return ""; + } + + std::string message; + message.resize(msg_size); + int recv = 0; + int ret = 1; + while (recv < msg_size && ret > 0) { + ret = read_helper((uint8_t*)&message[recv], msg_size - recv); + recv += ret > 0 ? ret : 0; + } + + if (recv != msg_size) { + LOG(ERROR) << "Received partial message, expected size " << msg_size + << " found : " << recv; + LOG(ERROR) << "Message received = " << message; + return ""; + } + + return message; } - return ret > 0; - } - private: - int read_helper(uint8_t* buf, int size) { - int ret = ::read(client_sock_fd_, (void*)buf, size); - if (ret == -1) { - std::perror("read()"); + bool send_response(const std::string& response) + { + int32_t size = response.size(); + int ret; + if (FLAGS_certs_dir == NO_CERTS_MODE) { + ret = ::write(client_sock_fd_, (void*)&size, sizeof(size)); + if (ret == -1) { + std::perror("write()"); + return false; + } + } else { + ret = SSL_write(ssl_, (void*)&size, sizeof(size)); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + } + int sent = 0; + while (sent < size && ret > 0) { + if (FLAGS_certs_dir == NO_CERTS_MODE) { + ret = ::write(client_sock_fd_, (void*)&response[sent], size - sent); + if (ret == -1) { + std::perror("write()"); + } else { + sent += ret; + } + } else { + ret = SSL_write(ssl_, (void*)&response[sent], size - sent); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + } else { + sent += ret; + } + } + } + + if (sent < response.size()) { + LOG(ERROR) << "Unable to write full response"; + return false; + } + return ret > 0; } - return ret; - } - int client_sock_fd_ = -1; +private: + int read_helper(uint8_t* buf, int size) + { + if (FLAGS_certs_dir == NO_CERTS_MODE) { + int ret = ::read(client_sock_fd_, (void*)buf, size); + if (ret == -1) { + std::perror("read()"); + } + return ret; + } + int ret = SSL_read(ssl_, (void*)buf, size); + if (ret <= 0) { + ERR_print_errors_fp(stderr); + } + return ret; + } + int client_sock_fd_ = -1; + SSL* ssl_ = nullptr; }; /* Accepts socket connections and processes the payloads. - * This will inturn call the Handler functions*/ -void SimpleJsonServerBase::loop() noexcept { - if (sock_fd_ == -1 || !initSuccess_) { - return; - } + * This will inturn call the Handler functions */ +void SimpleJsonServerBase::loop() noexcept +{ + if (sock_fd_ == -1 || !initSuccess_) { + return; + } + + while (run_) { + processOne(); + } +} + +void SimpleJsonServerBase::processOne() noexcept +{ + LOG(INFO) << "Waiting for connection."; + ClientSocketWrapper client; + if (!client.accept(sock_fd_, ctx_)) { + return; + } + std::string request_str = client.get_message(); + LOG(INFO) << "RPC message received = " << request_str; + auto response_str = processOneImpl(request_str); + if (response_str.empty()) { + return; + } + if (!client.send_response(response_str)) { + LOG(ERROR) << "Failed to send response"; + } +} + +void SimpleJsonServerBase::run() +{ + LOG(INFO) << "Launching RPC thread"; + thread_ = std::make_unique([this]() { this->loop(); }); +} + +void SimpleJsonServerBase::init_openssl() +{ + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); +} + +SSL_CTX* SimpleJsonServerBase::create_context() +{ + const SSL_METHOD* method = TLS_server_method(); + SSL_CTX* ctx = SSL_CTX_new(method); + if (!ctx) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Unable to create SSL context"); + } + return ctx; +} + +static bool is_cert_revoked(X509* cert, X509_STORE* store) +{ + if (!cert || !store) { + LOG(ERROR) << "Invalid certificate or store pointer"; + return false; + } + // 获取证书的颁发者名称 + X509_NAME* issuer = X509_get_issuer_name(cert); + if (!issuer) { + LOG(ERROR) << "Failed to get certificate issuer"; + return false; + } + // 获取证书的序列号 + const ASN1_INTEGER* serial = X509_get_serialNumber(cert); + if (!serial) { + LOG(ERROR) << "Failed to get certificate serial number"; + return false; + } + // 创建证书验证上下文 + X509_STORE_CTX* ctx = X509_STORE_CTX_new(); + if (!ctx) { + LOG(ERROR) << "Failed to create certificate store context"; + return false; + } + bool is_revoked = false; + try { + // 初始化证书验证上下文 + if (!X509_STORE_CTX_init(ctx, store, cert, nullptr)) { + LOG(ERROR) << "Failed to initialize certificate store context"; + X509_STORE_CTX_free(ctx); + return false; + } + // 获取CRL列表 + STACK_OF(X509_CRL)* crls = X509_STORE_CTX_get1_crls(ctx, issuer); + if (!crls) { + LOG(INFO) << "No CRLs found for issuer"; + X509_STORE_CTX_free(ctx); + return false; + } + time_t current_time = time(nullptr); + for (int i = 0; i < sk_X509_CRL_num(crls); i++) { + X509_CRL* crl = sk_X509_CRL_value(crls, i); + if (!crl) { + LOG(ERROR) << "Invalid CRL at index " << i; + continue; + } + // 检查 CRL 的有效期 + const ASN1_TIME* crl_this_update = X509_CRL_get0_lastUpdate(crl); + const ASN1_TIME* crl_next_update = X509_CRL_get0_nextUpdate(crl); + if (!crl_this_update) { + LOG(ERROR) << "Failed to get CRL this_update time"; + continue; + } + // 检查 CRL 是否已生效 + if (X509_cmp_time(crl_this_update, ¤t_time) > 0) { + LOG(INFO) << "CRL is not yet valid"; + continue; + } + // 检查 CRL 是否过期 + if (crl_next_update) { + if (X509_cmp_time(crl_next_update, ¤t_time) < 0) { + LOG(INFO) << "CRL has expired"; + continue; + } + } + // 检查证书是否在 CRL 中 + STACK_OF(X509_REVOKED)* revoked = X509_CRL_get_REVOKED(crl); + if (revoked) { + for (int j = 0; j < sk_X509_REVOKED_num(revoked); j++) { + X509_REVOKED* rev = sk_X509_REVOKED_value(revoked, j); + if (rev) { + const ASN1_INTEGER* rev_serial = X509_REVOKED_get0_serialNumber(rev); + if (rev_serial && ASN1_INTEGER_cmp(serial, rev_serial) == 0) { + LOG(INFO) << "Certificate is found in CRL"; + is_revoked = true; + break; + } + } + } + } + if (is_revoked) { + break; + } + } + if (crls) { + sk_X509_CRL_pop_free(crls, X509_CRL_free); + } + } catch (const std::exception& e) { + LOG(ERROR) << "Exception while checking CRL: " << e.what(); + is_revoked = false; + } + X509_STORE_CTX_free(ctx); + return is_revoked; +} + + +// 禁用终端回显的函数,但显示星号 +std::string get_password_with_stars() +{ + struct termios old_flags; + struct termios new_flags; + std::string password; + + // 获取当前终端设置 + tcgetattr(fileno(stdin), &old_flags); + new_flags = old_flags; + new_flags.c_lflag &= ~ECHO; // 禁用回显 + + // 设置新的终端属性 + tcsetattr(fileno(stdin), TCSANOW, &new_flags); + + // 读取密码并显示星号 + char ch; + while ((ch = getchar()) != '\n') { + if (ch == 127 || ch == 8) { // 处理退格键 (ASCII 127 或 8) + if (!password.empty()) { + password.pop_back(); + std::cout << "\b \b"; // 删除一个星号 + std::cout.flush(); // 立即刷新输出 + } + } else { + password += ch; + std::cout << '*'; // 显示星号 + std::cout.flush(); // 立即刷新输出 + } + } + + // 恢复原来的终端设置 + tcsetattr(fileno(stdin), TCSANOW, &old_flags); - while (run_) { - processOne(); - } + return password; } -void SimpleJsonServerBase::processOne() noexcept { - /* Wait for incoming connections, the accept call is blocking*/ - LOG(INFO) << "Waiting for connection."; - ClientSocketWrapper client; - if (!client.accept(sock_fd_)) { - return; - } +// 验证证书版本和签名算法 +void SimpleJsonServerBase::verify_certificate_version_and_algorithm(X509* cert) +{ + // 1. 检查证书版本是否为 X.509v3 + if (X509_get_version(cert) != 2) { // 2 表示 X.509v3 + throw std::runtime_error("Certificate is not X.509v3"); + } + + // 2. 检查证书签名算法 + const X509_ALGOR* sig_alg = X509_get0_tbs_sigalg(cert); + if (!sig_alg) { + throw std::runtime_error("Failed to get signature algorithm"); + } + + int sig_nid = OBJ_obj2nid(sig_alg->algorithm); + // 检查是否使用不安全的算法 + if (sig_nid == NID_md2WithRSAEncryption || + sig_nid == NID_md5WithRSAEncryption || + sig_nid == NID_sha1WithRSAEncryption) { + throw std::runtime_error("Certificate uses insecure signature algorithm: " + std::to_string(sig_nid)); + } +} + +// 验证 RSA 密钥长度 +void SimpleJsonServerBase::verify_rsa_key_length(EVP_PKEY* pkey) +{ + if (EVP_PKEY_base_id(pkey) == EVP_PKEY_RSA) { + size_t key_length = 0; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + // OpenSSL 3.0 及以上版本 + key_length = EVP_PKEY_get_size(pkey) * 8; // 转换为位数 +#else + // OpenSSL 1.1.1 及以下版本 + RSA* rsa = EVP_PKEY_get0_RSA(pkey); + if (!rsa) { + throw std::runtime_error("Failed to get RSA key"); + } + + const BIGNUM* n = nullptr; + RSA_get0_key(rsa, &n, nullptr, nullptr); + if (!n) { + throw std::runtime_error("Failed to get RSA modulus"); + } + + key_length = BN_num_bits(n); +#endif + if (key_length < MIN_RSA_KEY_LENGTH) { + throw std::runtime_error("RSA key length " + std::to_string(key_length) + " bits is less than required " + std::to_string(MIN_RSA_KEY_LENGTH) + " bits"); + } + } +} + +// 验证证书有效期 +void SimpleJsonServerBase::verify_certificate_validity(X509* cert) +{ + ASN1_TIME* not_before = X509_get_notBefore(cert); + ASN1_TIME* not_after = X509_get_notAfter(cert); + if (!not_before || !not_after) { + throw std::runtime_error("Failed to get certificate validity period"); + } + + time_t current_time = time(nullptr); + struct tm tm_before = {}; + struct tm tm_after = {}; + if (!ASN1_TIME_to_tm(not_before, &tm_before) || + !ASN1_TIME_to_tm(not_after, &tm_after)) { + throw std::runtime_error("Failed to convert certificate dates"); + } + + time_t not_before_time = mktime(&tm_before); + time_t not_after_time = mktime(&tm_after); + + // 检查证书是否已生效 + if (current_time < not_before_time) { + BIO* bio = BIO_new(BIO_s_mem()); + if (bio) { + ASN1_TIME_print(bio, not_before); + char* not_before_str = nullptr; + long len = BIO_get_mem_data(bio, ¬_before_str); + if (len > 0) { + std::string time_str(not_before_str, len); + BIO_free(bio); + throw std::runtime_error("Server certificate is not yet valid. Valid from: " + time_str); + } + BIO_free(bio); + } + throw std::runtime_error("Server certificate is not yet valid"); + } - std::string request_str = client.get_message(); - LOG(INFO) << "RPC message received = " << request_str; + // 检查证书是否已过期 + if (current_time > not_after_time) { + BIO* bio = BIO_new(BIO_s_mem()); + if (bio) { + ASN1_TIME_print(bio, not_after); + char* not_after_str = nullptr; + long len = BIO_get_mem_data(bio, ¬_after_str); + if (len > 0) { + std::string time_str(not_after_str, len); + BIO_free(bio); + throw std::runtime_error("Server certificate has expired. Expired at: " + time_str); + } + BIO_free(bio); + } + throw std::runtime_error("Server certificate has expired"); + } +} - auto response_str = processOneImpl(request_str); - if (response_str.empty()) { - return; - } +// 验证证书扩展域 +void SimpleJsonServerBase::verify_certificate_extensions(X509* cert) +{ + bool has_ca_constraint = false; + bool has_key_usage = false; + bool has_cert_sign = false; + bool has_crl_sign = false; + + const STACK_OF(X509_EXTENSION)* exts = X509_get0_extensions(cert); + if (exts) { + for (int i = 0; i < sk_X509_EXTENSION_num(exts); i++) { + X509_EXTENSION* ext = sk_X509_EXTENSION_value(exts, i); + ASN1_OBJECT* obj = X509_EXTENSION_get_object(ext); + + if (OBJ_obj2nid(obj) == NID_basic_constraints) { + BASIC_CONSTRAINTS* constraints = (BASIC_CONSTRAINTS*)X509V3_EXT_d2i(ext); + if (constraints) { + has_ca_constraint = constraints->ca; + BASIC_CONSTRAINTS_free(constraints); + } + } else if (OBJ_obj2nid(obj) == NID_key_usage) { + ASN1_BIT_STRING* usage = (ASN1_BIT_STRING*)X509V3_EXT_d2i(ext); + if (usage) { + has_key_usage = true; + has_cert_sign = (usage->data[0] & KU_KEY_CERT_SIGN) != 0; + has_crl_sign = (usage->data[0] & KU_CRL_SIGN) != 0; + ASN1_BIT_STRING_free(usage); + } + } + } + } - if (!client.send_response(response_str)) { - LOG(ERROR) << "Failed to send response"; - } + if (has_ca_constraint) { + throw std::runtime_error("Client certificate should not have CA constraint"); + } + if (!has_key_usage) { + throw std::runtime_error("Client certificate must have key usage extension"); + } +} - return; +// 加载私钥 +void SimpleJsonServerBase::load_private_key(SSL_CTX* ctx, const std::string& server_key) +{ + FILE* key_file = fopen(server_key.c_str(), "r"); + if (!key_file) { + throw std::runtime_error("Failed to open server key file"); + } + + bool is_encrypted = false; + char line[256]; + while (fgets(line, sizeof(line), key_file)) { + if (strstr(line, "ENCRYPTED")) { + is_encrypted = true; + break; + } + } + rewind(key_file); + + if (is_encrypted) { + std::string password; + std::cout << "Please enter the certificate password: "; + password = get_password_with_stars(); + std::cout << std::endl; + + EVP_PKEY* pkey = PEM_read_PrivateKey( + key_file, + nullptr, + [](char* buf, int size, int rwflag, void* userdata) -> int { + const std::string* password = static_cast(userdata); + if (password->size() > static_cast(size)) { + return 0; + } + password->copy(buf, password->size()); + return password->size(); + }, + const_cast(&password)); + + fclose(key_file); + secure_clear_password(password); + + if (!pkey) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load encrypted server private key"); + } + + if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) { + EVP_PKEY_free(pkey); + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to use server private key"); + } + + EVP_PKEY_free(pkey); + } else { + fclose(key_file); + if (SSL_CTX_use_PrivateKey_file(ctx, server_key.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load server private key"); + } + } } -void SimpleJsonServerBase::run() { - LOG(INFO) << "Launching RPC thread"; - thread_ = std::make_unique([this]() { this->loop(); }); +// 加载和验证 CRL +void SimpleJsonServerBase::load_and_verify_crl(SSL_CTX* ctx, const std::string& crl_file) +{ + X509_STORE* store = SSL_CTX_get_cert_store(ctx); + if (!store) { + throw std::runtime_error("Failed to get certificate store"); + } + + if (access(crl_file.c_str(), F_OK) != -1) { + FILE* crl_file_ptr = fopen(crl_file.c_str(), "r"); + if (!crl_file_ptr) { + LOG(WARNING) << "Failed to open CRL file: " << crl_file; + return; + } + + X509_CRL* crl = PEM_read_X509_CRL(crl_file_ptr, nullptr, nullptr, nullptr); + fclose(crl_file_ptr); + + if (!crl) { + LOG(WARNING) << "Failed to read CRL from file: " << crl_file; + return; + } + + if (X509_STORE_add_crl(store, crl) != 1) { + LOG(WARNING) << "Failed to add CRL to certificate store"; + X509_CRL_free(crl); + return; + } + + X509* cert = SSL_CTX_get0_certificate(ctx); + if (!cert) { + X509_CRL_free(crl); + throw std::runtime_error("Failed to get server certificate"); + } + + if (is_cert_revoked(cert, store)) { + X509_CRL_free(crl); + throw std::runtime_error("Server certificate is revoked"); + } + + X509_CRL_free(crl); + } +} + +void SimpleJsonServerBase::configure_context(SSL_CTX* ctx) +{ + if (FLAGS_certs_dir.empty()) { + throw std::runtime_error("--certs-dir must be specified!"); + } + + std::string certs_dir = FLAGS_certs_dir; + if (!certs_dir.empty() && certs_dir.back() != '/') + certs_dir += '/'; + + std::string server_cert = certs_dir + "server.crt"; + std::string server_key = certs_dir + "server.key"; + std::string ca_cert = certs_dir + "ca.crt"; + std::string crl_file = certs_dir + "ca.crl"; + + LOG(INFO) << "Loading server cert: " << server_cert; + LOG(INFO) << "Loading server key: " << server_key; + LOG(INFO) << "Loading CA cert: " << ca_cert; + + // 1. 加载并验证服务器证书 + if (SSL_CTX_use_certificate_file(ctx, server_cert.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load server certificate"); + } + + X509* cert = SSL_CTX_get0_certificate(ctx); + if (!cert) { + throw std::runtime_error("Failed to get server certificate"); + } + + // 2. 验证证书版本和签名算法 + verify_certificate_version_and_algorithm(cert); + + // 3. 验证 RSA 密钥长度 + EVP_PKEY* pkey = X509_get_pubkey(cert); + if (!pkey) { + throw std::runtime_error("Failed to get public key"); + } + verify_rsa_key_length(pkey); + EVP_PKEY_free(pkey); + + // 4. 验证证书有效期 + verify_certificate_validity(cert); + + // 5. 验证证书扩展域 + verify_certificate_extensions(cert); + + // 6. 加载私钥 + load_private_key(ctx, server_key); + + // 7. 加载 CA 证书 + if (SSL_CTX_load_verify_locations(ctx, ca_cert.c_str(), NULL) <= 0) { + ERR_print_errors_fp(stderr); + throw std::runtime_error("Failed to load CA certificate"); + } + + // 8. 加载和验证 CRL + load_and_verify_crl(ctx, crl_file); + + // 9. 设置证书验证选项 + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); +} + +void secure_clear_password(std::string& password) +{ + if (!password.empty()) { + // 使用随机数据覆盖密码 + std::generate(password.begin(), password.end(), std::rand); + // 清空字符串 + password.clear(); + // 收缩字符串容量,释放内存 + password.shrink_to_fit(); + } } -} // namespace dynolog +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h index 52be6e4cf30..9efe0d086ca 100644 --- a/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h +++ b/msmonitor/dynolog_npu/dynolog/src/rpc/SimpleJsonServer.h @@ -10,54 +10,71 @@ #include #include #include +#include +#include #include "dynolog/src/ServiceHandler.h" -namespace dynolog { +DECLARE_string(certs_dir); +namespace dynolog { // This is a simple service built using UNIX Sockets // with remote procedure calls implemented via JSON string. - class SimpleJsonServerBase { - public: - explicit SimpleJsonServerBase(int port); - - virtual ~SimpleJsonServerBase(); - - int getPort() const { - return port_; - } - - bool initSuccessful() const { - return initSuccess_; - } - // spin up a new thread to process requets - void run(); - - void stop() { - run_ = 0; - thread_->join(); - } - - // synchronously processes a request - void processOne() noexcept; - - protected: - void initSocket(); - - // process requests in a loop - void loop() noexcept; - - // implement processing of request using the handler - virtual std::string processOneImpl(const std::string& request_str) { - return ""; - } - - int port_; - int sock_fd_{-1}; - bool initSuccess_{false}; - - std::atomic run_{true}; - std::unique_ptr thread_; +public: + explicit SimpleJsonServerBase(int port); + virtual ~SimpleJsonServerBase(); + + int getPort() const + { + return port_; + } + + bool initSuccessful() const + { + return initSuccess_; + } + // spin up a new thread to process requets + void run(); + + void stop() + { + run_ = 0; + thread_->join(); + } + + // synchronously processes a request + void processOne() noexcept; + +protected: + void initSocket(); + void init_openssl(); + SSL_CTX *create_context(); + void configure_context(SSL_CTX *ctx); + + // process requests in a loop + void loop() noexcept; + + // implement processing of request using the handler + virtual std::string processOneImpl(const std::string &request_str) + { + return ""; + } + + void verify_certificate_version_and_algorithm(X509 *cert); + void verify_rsa_key_length(EVP_PKEY *pkey); + void verify_certificate_validity(X509 *cert); + void verify_certificate_extensions(X509 *cert); + void load_private_key(SSL_CTX *ctx, const std::string &server_key); + void load_and_verify_crl(SSL_CTX *ctx, const std::string &crl_file); + + int port_; + int sock_fd_{-1}; + bool initSuccess_{false}; + + std::atomic run_{true}; + std::unique_ptr thread_; + + SSL_CTX *ctx_{nullptr}; }; -} // namespace dynolog +} // namespace dynolog \ No newline at end of file diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp index 8e5f0fe1757..2f419d3796a 100644 --- a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.cpp @@ -5,7 +5,8 @@ #include "dynolog/src/tracing/IPCMonitor.h" #include -#include +#include +#include #include #include #include @@ -20,105 +21,169 @@ namespace dynolog { namespace tracing { constexpr int kSleepUs = 10000; +constexpr int kDataMsgSleepUs = 1000; const std::string kLibkinetoRequest = "req"; const std::string kLibkinetoContext = "ctxt"; +const std::string kLibkinetoData = "data"; -IPCMonitor::IPCMonitor(const std::string& ipc_fabric_name) { - ipc_manager_ = FabricManager::factory(ipc_fabric_name); - // below ensures singleton exists - LOG(INFO) << "Kineto config manager : active processes = " - << LibkinetoConfigManager::getInstance()->processCount("0"); +IPCMonitor::IPCMonitor(const std::string& ipc_fabric_name) +{ + ipc_manager_ = FabricManager::factory(ipc_fabric_name); + data_ipc_manager_ = FabricManager::factory(ipc_fabric_name + "_data"); + // below ensures singleton exists + LOG(INFO) << "Kineto config manager : active processes = " + << LibkinetoConfigManager::getInstance()->processCount("0"); } -void IPCMonitor::loop() { - while (ipc_manager_) { - if (ipc_manager_->recv()) { - std::unique_ptr msg = ipc_manager_->retrieve_msg(); - processMsg(std::move(msg)); +void IPCMonitor::loop() +{ + while (ipc_manager_) { + if (ipc_manager_->recv()) { + std::unique_ptr msg = ipc_manager_->retrieve_msg(); + processMsg(std::move(msg)); + } + /* sleep override */ + usleep(kSleepUs); } - /* sleep override */ - usleep(kSleepUs); - } } -void IPCMonitor::processMsg(std::unique_ptr msg) { - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - // sizeof(msg->metadata.type) = 32, well above the size of the constant - // strings we are comparing against. memcmp is safe - if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) - msg->metadata.type, - kLibkinetoContext.data(), - kLibkinetoContext.size()) == 0) { - registerLibkinetoContext(std::move(msg)); - } else if ( - memcmp( // NOLINT(facebook-security-vulnerable-memcmp) - msg->metadata.type, - kLibkinetoRequest.data(), - kLibkinetoRequest.size()) == 0) { - getLibkinetoOnDemandRequest(std::move(msg)); - } else { - LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; - } +void IPCMonitor::dataLoop() +{ + while (data_ipc_manager_) { + if (data_ipc_manager_->recv()) { + std::unique_ptr msg = data_ipc_manager_->retrieve_msg(); + processDataMsg(std::move(msg)); + } + /* sleep override */ + usleep(kDataMsgSleepUs); + } +} + +void IPCMonitor::processMsg(std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + // sizeof(msg->metadata.type) = 32, well above the size of the constant + // strings we are comparing against. memcmp is safe + if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoContext.data(), + kLibkinetoContext.size()) == 0) { + registerLibkinetoContext(std::move(msg)); + } else if ( + memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoRequest.data(), + kLibkinetoRequest.size()) == 0) { + getLibkinetoOnDemandRequest(std::move(msg)); + } else { + LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; + } +} + +void tracing::IPCMonitor::setLogger(std::unique_ptr logger) +{ + logger_ = std::move(logger); +} + +void IPCMonitor::LogData(const nlohmann::json& result) +{ + auto timestamp = result["timestamp"].get(); + logger_->logUint("timestamp", timestamp); + auto duration = result["duration"].get(); + logger_->logUint("duration", duration); + auto deviceId = result["deviceId"].get(); + logger_->logStr("deviceId", std::to_string(deviceId)); + auto kind = result["kind"].get(); + logger_->logStr("kind", kind); + if (result.contains("domain") && result["domain"].is_string()) { + auto domain = result["domain"].get(); + logger_->logStr("domain", domain); + } + logger_->finalize(); +} + +void IPCMonitor::processDataMsg(std::unique_ptr msg) +{ + if (!data_ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + if (memcmp( // NOLINT(facebook-security-vulnerable-memcmp) + msg->metadata.type, + kLibkinetoData.data(), + kLibkinetoData.size()) == 0) { + std::string message = std::string((char*)msg->buf.get(), msg->metadata.size); + try { + nlohmann::json result = nlohmann::json::parse(message); + LOG(INFO) << "Received data message : " << result; + LogData(result); + } catch (nlohmann::json::parse_error&) { + LOG(ERROR) << "Error parsing message = " << message; + return; + } + } else { + LOG(ERROR) << "TYPE UNKOWN: " << msg->metadata.type; + } } void IPCMonitor::getLibkinetoOnDemandRequest( - std::unique_ptr msg) { - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - std::string ret_config = ""; - ipcfabric::LibkinetoRequest* req = - (ipcfabric::LibkinetoRequest*)msg->buf.get(); - if (req->n == 0) { - LOG(ERROR) << "Missing pids parameter for type " << req->type; - return; - } - std::vector pids(req->pids, req->pids + req->n); - try { - ret_config = LibkinetoConfigManager::getInstance()->obtainOnDemandConfig( - std::to_string(req->jobid), pids, req->type); - VLOG(0) << "getLibkinetoOnDemandRequest() : job id " << req->jobid + std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + std::string ret_config = ""; + ipcfabric::LibkinetoRequest* req = + (ipcfabric::LibkinetoRequest*)msg->buf.get(); + if (req->n == 0) { + LOG(ERROR) << "Missing pids parameter for type " << req->type; + return; + } + std::vector pids(req->pids, req->pids + req->n); + try { + ret_config = LibkinetoConfigManager::getInstance()->obtainOnDemandConfig( + std::to_string(req->jobid), pids, req->type); + VLOG(0) << "getLibkinetoOnDemandRequest() : job id " << req->jobid << " pids = " << pids[0]; - } catch (const std::runtime_error& ex) { - LOG(ERROR) << "Kineto config manager exception : " << ex.what(); - } - std::unique_ptr ret = - ipcfabric::Message::constructMessage( - ret_config, kLibkinetoRequest); - if (!ipc_manager_->sync_send(*ret, msg->src)) { - LOG(ERROR) << "Failed to return config to libkineto: IPC sync_send fail"; - } - - return; + } catch (const std::runtime_error& ex) { + LOG(ERROR) << "Kineto config manager exception : " << ex.what(); + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + ret_config, kLibkinetoRequest); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to return config to libkineto: IPC sync_send fail"; + } + return; } void IPCMonitor::registerLibkinetoContext( - std::unique_ptr msg) { - if (!ipc_manager_) { - LOG(ERROR) << "Fabric Manager not initialized"; - return; - } - ipcfabric::LibkinetoContext* ctxt = - (ipcfabric::LibkinetoContext*)msg->buf.get(); - int32_t size = -1; - try { - size = LibkinetoConfigManager::getInstance()->registerLibkinetoContext( - std::to_string(ctxt->jobid), ctxt->pid, ctxt->gpu); - } catch (const std::runtime_error& ex) { + std::unique_ptr msg) +{ + if (!ipc_manager_) { + LOG(ERROR) << "Fabric Manager not initialized"; + return; + } + ipcfabric::LibkinetoContext* ctxt = + (ipcfabric::LibkinetoContext*)msg->buf.get(); + int32_t size = -1; + try { + size = LibkinetoConfigManager::getInstance()->registerLibkinetoContext( + std::to_string(ctxt->jobid), ctxt->pid, ctxt->gpu); + } catch (const std::runtime_error& ex) { LOG(ERROR) << "Kineto config manager exception : " << ex.what(); - } - std::unique_ptr ret = - ipcfabric::Message::constructMessage( - size, kLibkinetoContext); - if (!ipc_manager_->sync_send(*ret, msg->src)) { - LOG(ERROR) << "Failed to send ctxt from dyno: IPC sync_send fail"; - } - - return; + } + std::unique_ptr ret = + ipcfabric::Message::constructMessage( + size, kLibkinetoContext); + if (!ipc_manager_->sync_send(*ret, msg->src)) { + LOG(ERROR) << "Failed to send ctxt from dyno: IPC sync_send fail"; + } + return; } } // namespace tracing diff --git a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h index 25d10e2ec7f..cbc59fd2bbc 100644 --- a/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h +++ b/msmonitor/dynolog_npu/dynolog/src/tracing/IPCMonitor.h @@ -11,27 +11,34 @@ #define USE_GOOGLE_LOG #include "dynolog/src/ipcfabric/FabricManager.h" +#include "dynolog/src/Logger.h" namespace dynolog { namespace tracing { class IPCMonitor { - public: - using FabricManager = dynolog::ipcfabric::FabricManager; - IPCMonitor(const std::string& ipc_fabric_name = "dynolog"); - virtual ~IPCMonitor() {} - - void loop(); - - public: - virtual void processMsg(std::unique_ptr msg); - void getLibkinetoOnDemandRequest(std::unique_ptr msg); - void registerLibkinetoContext(std::unique_ptr msg); - - std::unique_ptr ipc_manager_; +public: + using FabricManager = dynolog::ipcfabric::FabricManager; + IPCMonitor(const std::string& ipc_fabric_name = "dynolog"); + virtual ~IPCMonitor() {} + + void loop(); + void dataLoop(); + +public: + virtual void processMsg(std::unique_ptr msg); + virtual void processDataMsg(std::unique_ptr msg); + void getLibkinetoOnDemandRequest(std::unique_ptr msg); + void registerLibkinetoContext(std::unique_ptr msg); + void setLogger(std::unique_ptr logger); + void LogData(const nlohmann::json& result); + + std::unique_ptr ipc_manager_; + std::unique_ptr data_ipc_manager_; + std::unique_ptr logger_; // friend class test_case_name##_##test_name##_Test - friend class IPCMonitorTest_LibkinetoRegisterAndOndemandTest_Test; + friend class IPCMonitorTest_LibkinetoRegisterAndOndemandTest_Test; }; } // namespace tracing -- Gitee