diff --git a/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock b/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock index c038a7e2065f822e1a6bd3e1d382d221a4792d3d..32daa20846408c9df69e93120f702a0b9c3a11c2 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock +++ b/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock @@ -572,6 +572,7 @@ dependencies = [ name = "csv_parser" version = "0.1.0" dependencies = [ + "ahash", "anyhow", "csv", "serde", @@ -4394,6 +4395,7 @@ version = "0.1.0" dependencies = [ "ahash", "anyhow", + "csv_parser", "fsg", "layout", "parser", diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml index c81a357748ef772f41228f5840c1d1667dba3834..2f20e6147011541da58c2a1e2c1701ee85202123 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml +++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml @@ -18,6 +18,7 @@ smartstring = { workspace = true } ahash = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true +csv_parser = { path = "../../rust/csv_parser" } parser = { path = "../../rust/parser" } layout = { path = "../../rust/layout" } fsg = { path = "../../rust/fsg" } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs index ab5f7458b814554925c65c6c90054abad1e0b1cb..9f5bfbc2b14942f5e9e6e7df5e73c796de2cc906 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs @@ -10,6 +10,8 @@ use fsg::{result::JSONResult, fsgs_model}; use tauri::{path::BaseDirectory, AppHandle, Emitter, Manager, Result as InvokeResult}; use tauri::async_runtime::spawn_blocking; use tauri_plugin_shell::ShellExt; +use csv_parser::operator::OperatorGroup; +use csv_parser::parse_operator_csv; #[allow(non_snake_case)] #[derive(Debug, Clone, Serialize, Deserialize)] @@ -388,7 +390,7 @@ fn read_from<'a, T: DeserializeOwned, P: AsRef>(path: P) -> T { } #[tauri::command] -pub fn mine_fsg(path: &str, name: &str, min_sup: usize, min: usize, max: usize) -> InvokeResult> { +pub async fn mine_fsg(path: &str, name: &str, min_sup: usize, min: usize, max: usize) -> InvokeResult> { let mut model = parse_bin(&path)?; let single_graph = if model.name == name { Model { @@ -425,6 +427,11 @@ fn recursive_get_single_graph(model: &mut Model, name: &str, depth: usize) -> Re Err(anyhow!("Subgraph with name '{}' not found", name)) } +#[tauri::command] +pub async fn analyze_duration(path: &str) -> InvokeResult> { + wrap(parse_operator_csv(path)) +} + fn wrap(src: Result) -> InvokeResult { src.map_err(|e| e.into()) } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/lib.rs b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/lib.rs index 83c6237d48507ead38cc97cf652f3f95fca0d70a..e37f2d4a1627c468ffdd6705302711f8821f9593 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/lib.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/lib.rs @@ -1,13 +1,13 @@ mod commands; -use commands::{layout_bin, mine_fsg}; +use commands::{layout_bin, mine_fsg, analyze_duration}; #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { tauri::Builder::default() .plugin(tauri_plugin_dialog::init()) .plugin(tauri_plugin_shell::init()) - .invoke_handler(tauri::generate_handler![layout_bin, mine_fsg]) + .invoke_handler(tauri::generate_handler![layout_bin, mine_fsg, analyze_duration]) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/Cargo.toml b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/Cargo.toml index 5dd1fa0c1e9a9266730bb49514e2c57b2df83fad..3634b7f90de645c2ab469a2cb5b10aa447a793fe 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/Cargo.toml +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/Cargo.toml @@ -6,5 +6,6 @@ edition = "2024" [dependencies] anyhow = { workspace = true } thiserror = { workspace = true } +ahash = { workspace = true, features = ["serde"] } csv = "1.3.1" serde = { version = "1.0", features = ["derive"] } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/lib.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/lib.rs index c47534cacb6d7af8ee2d203800603e3a33c177d9..d6b3d216d1be34f978454c64de7224fe8e24545a 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/lib.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/lib.rs @@ -1,33 +1,55 @@ /* * Copyright (c), Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ -use std::collections::HashMap; +use ahash::HashMap; use std::path::Path; use anyhow::{anyhow, Result}; -use crate::operator::{OperatorGroup, DURATION_KEY, MTE2_TIME_KEY, NAME_KEY, TYPE_KEY}; -use crate::parser::ValidHeaderField; +use crate::operator::{group_times_by_op_type, OperatorGroup}; +use crate::parser::{parse_operator, ParseFile}; pub mod operator; pub mod parser; +mod utils; + +const KERNEL_DETAIL_CSV: &str = "kernel_details.csv"; pub fn is_valid_csv(name: &str) -> bool { - todo!() + // 情况1:精确匹配 + if name == KERNEL_DETAIL_CSV { + return true; + } + + // 情况2:匹配 op_summary{*}.csv + if !name.starts_with("op_summary") || !name.ends_with(".csv") { + return false; + } + + true } -pub fn get_valid_fields(name: &str) -> Vec { - todo!() +pub fn get_parse_file(name: &str) -> ParseFile { + if name == KERNEL_DETAIL_CSV { + return ParseFile::KernelDetail; + } + ParseFile::OpSummary } pub fn parse_operator_csv(path: &str) -> Result> { - todo!() + let name = get_filename(path)?; + if !is_valid_csv(&name) { + return Err(anyhow!("Invalid CSV filename: {}", &name)); + } + let file_type = get_parse_file(&name); + let operators = parse_operator(path, file_type)?; + + Ok(group_times_by_op_type(&operators)) } pub fn get_filename(path: &str) -> Result { let path = Path::new(path); - let file_name = match path.file_name() { - Some(name) => name.to_string_lossy(), - None => return Err(anyhow!("没有文件名: {}", path.display())), - }; + let file_name = path.file_name() + .ok_or_else(|| anyhow::anyhow!("No filename component in path: {}", path.display()))?; - Ok(file_name.as_ref().to_string()) + // to_string_lossy() 返回 Cow,直接调用 to_string() 即可 + Ok(file_name.to_string_lossy().to_string()) } \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/operator.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/operator.rs index c1ff226c540829e46d8cc084c77dab2f7c75bd22..fe8e66e1b9ee7fe4f32709bc32a7b595b6366ecf 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/operator.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/operator.rs @@ -1,76 +1,70 @@ /* * Copyright (c), Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ -use std::collections::HashMap; +use std::collections::BTreeMap; +use ahash::HashMap; use std::string::ToString; use serde::Serialize; -use anyhow::Result; -use thiserror::Error; -pub const NAME_KEY: &str = "name"; -pub const TYPE_KEY: &str = "op_type"; -pub const DURATION_KEY: &str = "duration_us"; -pub const MTE2_TIME_KEY: &str = "mte2_time_us"; - -#[derive(Error, Debug)] -pub enum CreateError { - #[error("转换数字失败: {0}")] - FailTransformNumber(String), - - #[error("无效的数字")] - InvalidNumber, -} - -pub trait Creator { - fn create_by_row(row: &HashMap<&str, &str>) -> Result where Self: Sized; -} +use crate::parser::{OperatorFromKernelDetail, OperatorFromOpSummary}; #[derive(Debug, Serialize)] pub struct Operator { pub name: String, pub op_type: String, pub duration_us: f64, // 算子耗时 - pub mte2_time_us: f64, // 算子内存搬运时间 -} - -impl Default for Operator { - fn default() -> Self { - Operator { - name: "".to_string(), - op_type: "".to_string(), - duration_us: 0.0, - mte2_time_us: 0.0, - } - } + pub memory_move_time_us: f64, // 算子内存搬运时间 } #[derive(Debug, Serialize)] pub struct OperatorGroup { - op_type: String, - duration_us: f64, - mte2_time_us: f64, + pub op_type: String, + pub avg_duration_us: f64, + pub avg_memory_move_time_us: f64, } -impl From for OperatorGroup { - fn from(value: Operator) -> Self { - OperatorGroup { - op_type: value.op_type, - duration_us: value.duration_us, - mte2_time_us: value.mte2_time_us, - } - } -} +pub fn group_times_by_op_type(operators: &Vec) -> HashMap { + // 创建一个临时 HashMap 用于存储每种类型的操作符信息 + let mut type_map: BTreeMap = BTreeMap::new(); + + // 使用 fold 方法来累积数据 + operators.iter().fold(&mut type_map, |map, operator| { + let duration = operator.duration_us; + let mte2_time = operator.memory_move_time_us; + let entry = map.entry(operator.op_type.to_string()).or_insert((0.0, 0.0, 0)); + entry.0 += duration; + entry.1 += mte2_time; + entry.2 += 1; + map + }); -pub fn group_times_by_op_type(operators: &Vec) -> Vec { - vec![] + // 计算平均值并转换为 OperatorGroup + type_map.into_iter().map(|(op_type, (total_duration, total_mte2, count))| { + (op_type.clone(), OperatorGroup { + op_type, + avg_duration_us: total_duration / count as f64, + avg_memory_move_time_us: total_mte2 / count as f64, + }) + }).collect() } -impl Creator for Operator { - fn create_by_row(row: &HashMap<&str, &str>) -> Result { - todo!() +impl From for Operator { + fn from(value: OperatorFromKernelDetail) -> Self { + Self { + name: value.name, + op_type: value.type_field, + duration_us: value.duration_us, + memory_move_time_us: value.aiv_mte2_time_us + value.aiv_mte3_time_us, + } } } -/// 解析 string 转为 f64 -fn transform_f64(field: &str) -> Result { - todo!() +impl From for Operator { + fn from(value: OperatorFromOpSummary) -> Self { + Self { + name: value.name, + op_type: value.type_field, + duration_us: value.duration_us, + memory_move_time_us: value.aiv_mte2_time_us + value.aiv_mte3_time_us, + } + } } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/parser.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/parser.rs index 87902ac976592c8d81a6b1bc8705fb7b42503659..fc77e21d2801add7c3b70f594860f6594fc880bb 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/parser.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/parser.rs @@ -1,64 +1,180 @@ /* * Copyright (c), Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ -use std::collections::HashMap; +use std::fs::File; use std::io::BufReader; use csv::Reader; -use serde::Serialize; +use serde::de::DeserializeOwned; +use serde::Deserialize; use thiserror::Error; -use crate::operator::{CreateError, Creator, Operator}; -#[derive(Debug, Serialize)] -pub struct ParseResult { - pub rows: Vec, - pub warnings: Vec, +use crate::operator::Operator; +use crate::utils::check_csv_injection; + +// 自定义解析函数:将字符串解析为 f64(自动提取数字,忽略单位) +fn parse_micros_f64<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + if s.starts_with("N/A") { // 如果 csv 有 N/A, 返回-1,在之后的操作中会过滤掉负数的行 + return Ok(-1.0); + } + // 只保留数字、小数点、负号 + let cleaned: String = s + .chars() + .filter(|c| c.is_ascii_digit() || *c == '.' || *c == '-') + .collect(); + + if cleaned.is_empty() { + return Err(serde::de::Error::custom("empty or invalid number")); + } + + cleaned + .parse::() + .map_err(|_| serde::de::Error::custom("failed to parse as f64")) +} + +// 设置来自 kernel_detail.csv 文件的解析 +#[derive(Debug, Deserialize)] +pub struct OperatorFromKernelDetail { + #[serde(rename = "Name")] + pub name: String, + + #[serde(rename = "Type")] + pub type_field: String, + + #[serde(rename = "Duration(us)", deserialize_with = "parse_micros_f64")] + pub duration_us: f64, + + #[serde(rename = "aiv_mte2_time(us)", deserialize_with = "parse_micros_f64")] + pub aiv_mte2_time_us: f64, + + #[serde(rename = "aiv_mte3_time(us)", deserialize_with = "parse_micros_f64")] + pub aiv_mte3_time_us: f64, +} + +// 设置来自 op_summary_{时间戳}.csv 文件的解析 +#[derive(Debug, Deserialize)] +pub struct OperatorFromOpSummary { + #[serde(rename = "Op Name")] + pub name: String, + + #[serde(rename = "OP Type")] + pub type_field: String, + + #[serde(rename = "Task Duration(us)", deserialize_with = "parse_micros_f64")] + pub duration_us: f64, + + #[serde(rename = "aiv_mte2_time(us)", deserialize_with = "parse_micros_f64")] + pub aiv_mte2_time_us: f64, + + #[serde(rename = "aiv_mte3_time(us)", deserialize_with = "parse_micros_f64")] + pub aiv_mte3_time_us: f64, } #[derive(Error, Debug)] pub enum ParseError { - #[error("CSV 格式错误: {0}")] + #[error("CSV format error: {0}")] CsvError(#[from] csv::Error), - #[error("IO 错误: {0}")] + #[error("IO error: {0}")] IoError(#[from] std::io::Error), - #[error("文件为空或无有效数据")] + #[error("File is empty or has invalid data")] EmptyData, +} - #[error("文件头为空")] - EmptyHeader, - - #[error("文件头格式错误,缺少需要的列: {0}")] - InvalidHeader(String), +#[derive(Debug)] +pub enum ParseFile { + KernelDetail, + OpSummary, +} - #[error("创建失败: {0}")] - CreateFail(String), +// 定义一个Trait,该Trait提供对共通属性的基本访问。 +pub trait OperatorInfo { + fn name(&self) -> &str; + fn type_field(&self) -> &str; + fn duration_us(&self) -> f64; + fn aiv_mte2_time_us(&self) -> f64; + fn aiv_mte3_time_us(&self) -> f64; } -pub struct ValidHeaderField { - pub h_field: String, - pub h_key: String, +impl OperatorInfo for OperatorFromKernelDetail { + fn name(&self) -> &str { + &self.name + } + fn type_field(&self) -> &str { + &self.type_field + } + fn duration_us(&self) -> f64 { + self.duration_us + } + fn aiv_mte2_time_us(&self) -> f64 { + self.aiv_mte2_time_us + } + fn aiv_mte3_time_us(&self) -> f64 { + self.aiv_mte3_time_us + } } -pub fn parse_operator(path: &str, parse_fields: Vec) -> Result, ParseError> { - todo!() +impl OperatorInfo for OperatorFromOpSummary { + fn name(&self) -> &str { + &self.name + } + fn type_field(&self) -> &str { + &self.type_field + } + fn duration_us(&self) -> f64 { + self.duration_us + } + fn aiv_mte2_time_us(&self) -> f64 { + self.aiv_mte2_time_us + } + fn aiv_mte3_time_us(&self) -> f64 { + self.aiv_mte3_time_us + } } -/// 从字节流中解析 CSV,返回算子列表与结构汇总 -pub fn parse_csv(reader: R, fields: Vec, creator: F) - -> Result, ParseError> -where - T: Creator, - F: Fn(&HashMap<&str, &str>) -> Result { - todo!() +pub fn parse_operator(path: &str, parse_file: ParseFile) -> Result, ParseError> { + let file = File::open(path)?; + let reader = BufReader::new(file); + + // 解析 CSV 文件 + let result = match parse_file { + ParseFile::KernelDetail => { + parse_csv::<_, OperatorFromKernelDetail>(reader)? + }, + ParseFile::OpSummary => { + parse_csv::<_, OperatorFromOpSummary>(reader)? + }, + }; + + Ok(result) // 返回解析后的 Operator 列表 } -fn get_header_index_map(csv_reader: &mut Reader, fields: Vec) - -> Result, ParseError> { - todo!() +fn valid_operator(row_count: usize, operator: &T) -> bool { + operator.name() == "" || check_csv_injection(&operator.name(), row_count) || + operator.type_field() == "" || check_csv_injection(&operator.type_field(), row_count) || + operator.duration_us() < 0.0 || operator.aiv_mte2_time_us() < 0.0 || + operator.aiv_mte3_time_us() < 0.0 } -/// 安全过滤:防止 CSV 注入(公式注入) -/// 常见恶意前缀:= + @ - -fn check_csv_injection(field: &str, row_count: usize, warnings: &mut Vec) -> bool { - todo!() +pub fn parse_csv(reader: R) -> Result, ParseError> +where T: OperatorInfo + DeserializeOwned + Into, { + let mut reader = Reader::from_reader(reader); + let mut operators = Vec::new(); + + for (row_count, result) in reader.deserialize().enumerate() { + let op_detail: T = result?; + if valid_operator(row_count + 1, &op_detail) { + continue; + } + let operator: Operator = op_detail.into(); // 转换 + operators.push(operator); + } + + if operators.len() < 1 { + return Err(ParseError::EmptyData); + } + Ok(operators) } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/utils.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/utils.rs new file mode 100644 index 0000000000000000000000000000000000000000..2138b733d5aefbf9195c82e580c5292fd2e71808 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/src/utils.rs @@ -0,0 +1,45 @@ +/// 安全过滤:防止 CSV 注入(公式注入) +/// 常见恶意前缀:= + @ - +pub fn check_csv_injection(field: &str, row_count: usize) -> bool { + if let Some(first_char) = field.chars().next() { + if "=+@-".contains(first_char) { + let e_field = escape_special_chars(field); + eprintln!("Line {}, Potential CSV formula injection detected, filtered: {}", row_count, &e_field); + return true; + } + } + false +} + +/// 处理字符串中的特殊符号,例如将 \n 替换为 \\n, \t 替换为 \\t 等 +/// +/// # 参数 +/// * `input` - 需要处理的原始字符串 +/// +/// # 返回值 +/// * 处理后的字符串 +pub fn escape_special_chars(input: &str) -> String { + let mut result = input.to_string(); + + // 特殊字符及其对应的转义序列 + let special_chars = [ + ("\n", "\\n"), + ("\t", "\\t"), + ("\r", "\\r"), + ("\"", "\\\""), + ("'", "\\'"), + ("\x08", "\\b"), // \b + ("\x0c", "\\f"), // \f + ("\x0b", "\\v"), // \v + ("\\", "\\\\"), + ]; + + // 使用 for 循环遍历并替换所有特殊字符 + for (char, escaped_char) in special_chars.iter() { + result = result.replace(*char, *escaped_char); + } + + result +} + + diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_detail.csv b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_details.csv similarity index 100% rename from plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_detail.csv rename to plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_details.csv diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_detail.txt b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_details.txt similarity index 100% rename from plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_detail.txt rename to plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/invalid/kernel_details.txt diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/parser_test.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/parser_test.rs index 16847705449539ffb3c338f4580a19c01430a3f6..34c0a407d660fb6e595430efde2c179fd5b6b6a7 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/parser_test.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/parser_test.rs @@ -4,65 +4,115 @@ #[cfg(test)] mod parser_tests { use std::io::Cursor; - use csv_parser::parser::{parse_csv, ValidHeaderField}; - use csv_parser::operator::{Creator, Operator, DURATION_KEY, MTE2_TIME_KEY, NAME_KEY, TYPE_KEY}; + use csv_parser::parser::{parse_csv, OperatorFromKernelDetail, OperatorFromOpSummary}; - fn get_fields() -> Vec { - vec![ - ValidHeaderField { h_field: "op_name".to_string(), h_key: NAME_KEY.to_string() }, - ValidHeaderField { h_field: "op_type".to_string(), h_key: TYPE_KEY.to_string() }, - ValidHeaderField { h_field: "duration".to_string(), h_key: DURATION_KEY.to_string() }, - ValidHeaderField { h_field: "mte2_time".to_string(), h_key: MTE2_TIME_KEY.to_string() }, - ] + #[test] + fn test_parse_valid_kernel_detail_csv() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us),Other\nconv2d,MatMul,120.0,20,200,\nrelu,Active,30,3,4,\nmatmul,MatMul,200,20,40,"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0].name, "conv2d"); + assert_eq!(result[0].duration_us, 120.0); + assert_eq!(result[0].memory_move_time_us, 220.0); } #[test] - fn test_parse_valid_csv() { - let csv_data = "op_name,op_type,duration,mte2_time\nconv2d,MatMul,120.0,20\nrelu,Active,30,3\nmatmul,MatMul,200,20"; + fn test_parse_valid_op_summary_csv() { + let csv_data = "Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us),Other\nconv2d,MatMul,120.0,20,200,\nrelu,Active,30,3,4,\nmatmul,MatMul,200,20,40,"; let cursor = Cursor::new(csv_data); - let fields = get_fields(); - let result = parse_csv(cursor, fields, Operator::create_by_row) - .unwrap(); + let result = parse_csv::<_, OperatorFromOpSummary>(cursor).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0].name, "conv2d"); + assert_eq!(result[0].duration_us, 120.0); + assert_eq!(result[0].memory_move_time_us, 220.0); + } - assert_eq!(result.rows.len(), 3); - assert_eq!(result.rows[0].name, "conv2d"); - assert_eq!(result.rows[0].duration_us, 120.0); - assert_eq!(result.rows[0].mte2_time_us, 20.0); + #[test] + fn test_parse_with_duplicate_mte2_column() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte2_time(us)\nconv2d,MatMul,1,1,0\nrelu,Active,30,3,4"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "CSV format error: CSV deserialize error: record 1 (line: 2, byte: 59): duplicate field `aiv_mte2_time(us)`"); + } + + #[test] + fn test_parse_with_missing_mte3_column() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us)\nconv2d,MatMul,1,1\nrelu,Active,30,3"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "CSV format error: CSV deserialize error: record 1 (line: 2, byte: 41): missing field `aiv_mte3_time(us)`"); + } + + #[test] + fn test_parse_with_na_duration_column() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\nconv2d,MatMul,N/A,N/A,N/A\nrelu,Active,1,1,1"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor).unwrap(); + assert_eq!(result[0].name, "relu"); } #[test] fn test_parse_with_invalid_duration() { - let csv_data = "op_name,op_type,duration,mte2_time\nconv2d,MatMul,abc,1\nrelu,Active,30,3"; + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\nconv2d,MatMul,abc,1,0\nrelu,Active,30,3,4"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "CSV format error: CSV deserialize error: record 1 (line: 2, byte: 59): empty or invalid number"); + } + + #[test] + fn test_parse_with_blank_duration_kernel_detail_csv() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\nconv2d,MatMul,1,1,0\nrelu,Active,,3,4"; let cursor = Cursor::new(csv_data); - let fields = get_fields(); - let result = parse_csv(cursor, fields, Operator::create_by_row) - .unwrap(); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "CSV format error: CSV deserialize error: record 2 (line: 3, byte: 79): empty or invalid number"); + } - assert_eq!(result.rows.len(), 1); // 只有 relu 被解析 - assert_eq!(result.warnings.len(), 1); - assert!(result.warnings[0].contains("abc")); + #[test] + fn test_parse_with_blank_type_kernel_detail_csv() { + let csv_data = "Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\nconv2d,,1,1,0\nrelu,Active,30,3,4"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromKernelDetail>(cursor).unwrap(); + assert_eq!(result[0].op_type, "Active"); } #[test] - fn test_csv_injection_filter() { - let csv_data = "op_name,op_type,duration,mte2_time\n=CMD|' /C calc'!A0,A,100,10\nsafe_op,Op,50,5"; + fn test_parse_with_blank_duration_op_summary_csv() { + let csv_data = "Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\nconv2d,MatMul,120.0,20,200\nrelu,Active,,3,4"; let cursor = Cursor::new(csv_data); - let fields = get_fields(); - let result = parse_csv(cursor, fields, Operator::create_by_row) - .unwrap(); + let result = parse_csv::<_, OperatorFromOpSummary>(cursor); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "CSV format error: CSV deserialize error: record 2 (line: 3, byte: 97): empty or invalid number"); + } + + #[test] + fn test_parse_with_blank_name_op_summary_csv() { + let csv_data = "Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\n,MatMul,120.0,20,200\nrelu,Active,30,3,4"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromOpSummary>(cursor).unwrap(); + assert_eq!(result[0].name, "relu"); + } - assert_eq!(result.rows.len(), 1); - assert_eq!(result.rows[0].name, "safe_op"); + #[test] + fn test_csv_injection_filter() { + let csv_data = "Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)\n=CMD|' /C calc'!A0,A,100,10,20\nsafe_op,Op,50,5,6"; + let cursor = Cursor::new(csv_data); + let result = parse_csv::<_, OperatorFromOpSummary>(cursor).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].name, "safe_op"); // 第一行因以 '=' 开头被过滤 } #[test] fn test_empty_csv() { - let csv_data = "op_name,op_type,duration,mte2_time"; // 无数据行 + let csv_data = "Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us)"; // 无数据行 let cursor = Cursor::new(csv_data); - let fields = get_fields(); - let result = parse_csv(cursor, fields, Operator::create_by_row); + let result = parse_csv::<_, OperatorFromOpSummary>(cursor); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "文件为空或无有效数据"); + assert_eq!(result.unwrap_err().to_string(), "File is empty or has invalid data"); } } \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/test.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/test.rs index fafdbb54454833aa9ef656a22eaa52c14175a6f2..58ae1d7a167ab5413f75e4d2e9116077c6fad3f3 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/test.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/test.rs @@ -11,7 +11,7 @@ mod tests { // 获取当前工作目录 let current_dir = env::current_dir().expect("Failed to get current directory"); - let filename = current_dir.join(r#"tests\valid\kernel_detail.csv"#); + let filename = current_dir.join(r#"tests\valid\kernel_details.csv"#); let path = filename.to_str().expect("Failed to convert path to string"); println!("{}", path); let name = get_filename(path).unwrap(); @@ -27,7 +27,7 @@ mod tests { fn test_invalid_path() { // 获取当前工作目录 let current_dir = env::current_dir().expect("Failed to get current directory"); - let filename = current_dir.join(r#"tests\invalid\kernel_detail.txt"#); + let filename = current_dir.join(r#"tests\invalid\kernel_details.txt"#); let path = filename.to_str().expect("Failed to convert path to string"); let name = get_filename(path).unwrap(); assert_eq!(is_valid_csv(&name), false); @@ -37,13 +37,17 @@ mod tests { fn test_parse_valid_csv() { // 获取当前工作目录 let current_dir = env::current_dir().expect("Failed to get current directory"); - let filename = current_dir.join(r#"tests\invalid\kernel_detail.csv"#); + let filename = current_dir.join(r#"tests\valid\kernel_details.csv"#); let path = filename.to_str().expect("Failed to convert path to string"); match parse_operator_csv(path) { Ok(ops) => { - assert_eq!(14, ops.len()); + assert_eq!(ops.len(), 2); + assert_eq!(ops.get("MatMul").unwrap().avg_duration_us, 120.0); + assert_eq!(ops.get("MatMul").unwrap().avg_memory_move_time_us, 23.0); + assert_eq!(ops.get("Active").unwrap().avg_duration_us, 30.0); + assert_eq!(ops.get("Active").unwrap().avg_memory_move_time_us, 7.0); }, - Err(err) => panic!("Error occurred while parsing ONNX model: {}", err), + Err(err) => panic!("Error occurred while parsing csv: {}", err), } } @@ -51,7 +55,7 @@ mod tests { fn test_parse_invalid_csv() { // 获取当前工作目录 let current_dir = env::current_dir().expect("Failed to get current directory"); - let filename = current_dir.join(r#"tests\invalid\kernel_detail.csv"#); + let filename = current_dir.join(r#"tests\invalid\kernel_details.csv"#); let path = filename.to_str().expect("Failed to convert path to string"); let result = parse_operator_csv(path); assert!(result.is_err()); diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_detail.csv b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_detail.csv deleted file mode 100644 index 93bff0ced6b7c129c112f7cf2cc977d9ecf169fd..0000000000000000000000000000000000000000 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_detail.csv +++ /dev/null @@ -1,4 +0,0 @@ -Name,Type,Duration(us),aiv_mte2_time(us), -conv2d,MatMul,120.0,20, -relu,Active,30,3, -matmul,MatMul,200,20 \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_details.csv b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_details.csv new file mode 100644 index 0000000000000000000000000000000000000000..40b48ac58d69b031a3bf211a5e0492c40c9d3f6e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/kernel_details.csv @@ -0,0 +1,4 @@ +Name,Type,Duration(us),aiv_mte2_time(us),aiv_mte3_time(us), +conv2d,MatMul,120.0,20,3, +relu,Active,30,3,4, +matmul,MatMul,200,N/A,5, \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/op_summary_20250812203130.csv b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/op_summary_20250812203130.csv index aecbb02435a4164e1ab8dfcf1ead7134b66cdad2..fd05c1f23500adbd8ece9e320f3df4a4978d90be 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/op_summary_20250812203130.csv +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/csv_parser/tests/valid/op_summary_20250812203130.csv @@ -1,4 +1,4 @@ -Op Name,Op Type,Task Duration(us),aiv_mte2_time(us), -conv2d,MatMul,120.0,20, -relu,Active,30,3, -matmul,MatMul,200,20 \ No newline at end of file +Op Name,OP Type,Task Duration(us),aiv_mte2_time(us),aiv_mte3_time(us),other +conv2d,MatMul,120.0,20,1, +relu,Active,30,3,2, +matmul,MatMul,200,20,3, \ No newline at end of file