From de925f09e421392f94621b1b4b5d0b0a23a9464c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A9=BA=E7=99=BD?= <3440771474@qq.com> Date: Sat, 7 Dec 2024 17:21:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20ModelVis=E8=A7=A3=E6=9E=90=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ModelVis/rust/parser/Cargo.toml | 17 ++ .../ModelVis/rust/parser/build.rs | 11 + .../ModelVis/rust/parser/proto/mind_ir.proto | 141 +++++++++++ .../ModelVis/rust/parser/src/lib.rs | 17 ++ .../ModelVis/rust/parser/src/main.rs | 116 +++++++++ .../ModelVis/rust/parser/src/model.rs | 33 +++ .../rust/parser/src/processors/mind_ir.rs | 222 +++++++++++++++++ .../rust/parser/src/processors/mod.rs | 1 + .../ModelVis/rust/parser/src/quick.rs | 234 ++++++++++++++++++ .../ModelVis/rust/parser/src/str_ext.rs | 77 ++++++ .../rust/parser/tests/test_strip_prefix.rs | 82 ++++++ 11 files changed, 951 insertions(+) create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/Cargo.toml create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/build.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/proto/mind_ir.proto create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/lib.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/main.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mind_ir.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mod.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/quick.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/str_ext.rs create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/tests/test_strip_prefix.rs diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/Cargo.toml b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/Cargo.toml new file mode 100644 index 000000000..5413cfc99 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/Cargo.toml @@ -0,0 +1,17 @@ +cargo-features = ["edition2024"] + +[package] +name = "parser" +version = "0.1.0" +edition = "2024" + +[dependencies] +prost = { version = "0.13.3" } +quick-protobuf = "0.8.1" +smartstring = {version = "1.0.1", features = ["serde"]} +ahash = { version = "0.8.11" } +serde = { version = "1.0.214", features = ["derive"] } +serde_json = "1.0.132" + +[build-dependencies] +prost-build = { version = "0.13.3" } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/build.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/build.rs new file mode 100644 index 000000000..cd58cdd65 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/build.rs @@ -0,0 +1,11 @@ +use prost_build::Config; + +fn main() { + let mut config = Config::new(); + config.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]"); + config + .compile_protos(&["proto/mind_ir.proto", "proto/ge_ir.proto", "proto/ge_onnx.proto"], &[ + "proto", + ]) + .unwrap(); +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/proto/mind_ir.proto b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/proto/mind_ir.proto new file mode 100644 index 000000000..e70e8f8e1 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/proto/mind_ir.proto @@ -0,0 +1,141 @@ +syntax = "proto2"; + +package mind_ir; + +message AttributeProto { + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + UINT8 = 2; + INT8 = 3; + UINT16 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + STRING = 8; + BOOL = 9; + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; + COMPLEX128 = 15; + BFLOAT16 = 16; + TENSOR = 17; + GRAPH = 18; + TENSORS = 19; + TUPLE = 20; + LIST = 21; + DICT = 22; + UMONAD = 23; + IOMONAD = 24; + NONE = 25; + PRIMITIVECLOSURE = 26; + FUNCGRAPHCLOSURE = 27; + PARTIALCLOSURE = 28; + UNIONFUNCCLOSURE = 29; + CSR_TENSOR = 30; + COO_TENSOR = 31; + ROW_TENSOR = 32; + CLASS_TYPE = 33; + NAME_SPACE = 34; + SYMBOL = 35; + TYPE_NULL = 36; + MAP_TENSOR = 37; + FUNCTOR = 38; + SCALAR = 39; + } + + required string name = 1; + repeated TensorProto tensors = 12; + required AttributeType type = 16; + repeated AttributeProto values = 17; +} + +message ValueInfoProto { + optional string name = 1; + repeated TensorProto tensor = 2; +} + +message NodeProto { + repeated string input = 1; + repeated string output = 2; + required string name = 3; + required string op_type = 4; + repeated AttributeProto attribute = 5; + optional string domain = 7; +} + +message ModelProto { + optional string producer_name = 2; + optional string domain = 4; + optional GraphProto graph = 7; + repeated GraphProto functions = 8; + optional PreprocessorProto preprocessor = 9; + repeated PrimitiveProto primitives = 12; +} + +message PreprocessorProto { + repeated PreprocessOpProto op = 1; +} + +message PreprocessOpProto { + optional string input_columns = 1; + optional string output_columns = 2; + optional string project_columns = 3; + optional string op_type = 4; + optional string operations = 5; + optional bool offload = 6; +} + +message GraphProto { + repeated NodeProto node = 1; + optional string name = 2; + repeated TensorProto parameter = 3; + repeated ValueInfoProto input = 5; + repeated ValueInfoProto output = 6; + repeated AttributeProto attribute = 8; +} + +message TensorProto { + + enum TensorDataType { + UNDEFINED = 0; + FLOAT = 1; + UINT8 = 2; + INT8 = 3; + UINT16 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + STRING = 8; + BOOL = 9; + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; + COMPLEX128 = 15; + BFLOAT16 = 16; + FLOAT64 = 17; + QINT4X2 = 18; + } + + optional string name = 7; + repeated int64 dims = 1; + required TensorDataType dtype = 2; +} + +message PrimitiveProto { + + enum PrimType { + PRIMITIVE = 1; + PRIMITIVE_FUNCTION = 2; + } + + required string name = 1; + required string op_type = 2; + repeated AttributeProto attribute = 3; + required string instance_name = 4; + optional PrimType prim_type = 5; +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/lib.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/lib.rs new file mode 100644 index 000000000..d85577f59 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/lib.rs @@ -0,0 +1,17 @@ +pub mod str_ext; +pub mod processors; +pub mod model; + +/// # Export +pub use model::Model; +pub use processors::mind_ir::parse_mindir_model; +pub use str_ext::StrExt; + +/// # TODO +/// Migrate from prost generation to handwritten definition because the performance consumption of `String Clone` is too high +/// +/// # Include +/// prost auto-generated code +pub mod mind_ir { + include!(concat!(env!("OUT_DIR"), "/mind_ir.rs")); +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/main.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/main.rs new file mode 100644 index 000000000..2dcc039a1 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/main.rs @@ -0,0 +1,116 @@ +use std::{ + fs::File, + io::{BufReader, Read}, + path::Path, + time::Instant, +}; + +mod proto; +mod quick; + +use prost::Message; +use quick_protobuf::{BytesReader, MessageRead}; +use parser::parse_mindir_model; +use crate::quick::quick_parse_model; + +fn extract_file_name(file_path: &str) -> String { + Path::new(file_path) + .file_name() + .and_then(|name| name.to_str()) + .map(|name| name.to_string()) + .unwrap_or_else(|| String::from("")) +} + +fn prost_simplex(path: &str) { + use parser::mind_ir::ModelProto; + + let start = Instant::now(); + + let file = File::open(path).unwrap(); + let mut reader = BufReader::new(file); + let mut buffer = vec![]; + reader.read_to_end(&mut buffer).unwrap(); + + let _: ModelProto = Message::decode(&*buffer).unwrap(); + + let cost = start.elapsed(); + + println!("prost simplex {} 【{cost:?}】", extract_file_name(path)) +} + +fn quick_simplex(path: &str) { + use proto::mind_ir::ModelProto; + + let start = Instant::now(); + + let file = File::open(path).unwrap(); + let mut buf_reader = BufReader::new(file); + let mut buffer = vec![]; + buf_reader.read_to_end(&mut buffer).unwrap(); + let mut reader = BytesReader::from_bytes(&buffer); + let _ = ModelProto::from_reader(&mut reader, &buffer).expect("Cannot read FooBar"); + let cost = start.elapsed(); + + println!("quick simplex {} 【{cost:?}】", extract_file_name(path)) +} + +fn quick_complex(path: &str) { + use proto::raw::ModelProto; + + let start = Instant::now(); + + let file = File::open(path).unwrap(); + let mut buf_reader = BufReader::new(file); + let mut buffer = vec![]; + buf_reader.read_to_end(&mut buffer).unwrap(); + let mut reader = BytesReader::from_bytes(&buffer); + let _ = ModelProto::from_reader(&mut reader, &buffer).expect("Cannot read FooBar"); + let cost = start.elapsed(); + + println!("quick complex {} 【{cost:?}】", extract_file_name(path)) +} + +const PATH1: &str = r#"E:\pycharm\visual\multi-transformer.mindir"#; +const PATH2: &str = r#"E:\pycharm\visual\multi-transformer64.mindir"#; +const PATH3: &str = r#"E:\pycharm\visual\lenet.mindir"#; +const PATH4: &str = r#"E:\pycharm\visual\alex-net.mindir"#; +const PATH5: &str = r#"E:\pycharm\visual\resnet50_224.mindir"#; + +fn bench(path: &str) { + let s1 = Instant::now(); + quick_parse_model(path); + let cost1 = s1.elapsed(); + let s2 = Instant::now(); + parse_mindir_model(path); + let cost2 = s2.elapsed(); + + println!("quick cost: {cost1:?}, prost cost: {cost2:?}") +} + +fn main() { + prost_simplex(PATH1); + quick_simplex(PATH1); + quick_complex(PATH1); + + prost_simplex(PATH2); + quick_simplex(PATH2); + quick_complex(PATH2); + + prost_simplex(PATH3); + quick_simplex(PATH3); + quick_complex(PATH3); + + prost_simplex(PATH4); + quick_simplex(PATH4); + quick_complex(PATH4); + + prost_simplex(PATH5); + quick_simplex(PATH5); + quick_complex(PATH5); + + bench(PATH1); + bench(PATH2); + bench(PATH3); + bench(PATH4); + bench(PATH5); +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs new file mode 100644 index 000000000..88dce7ff7 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs @@ -0,0 +1,33 @@ +use ahash::HashMap; +use serde::{Deserialize, Serialize}; +use smartstring::alias::String; + +#[allow(non_snake_case)] +#[derive(Debug, Serialize, Deserialize)] +pub struct Node { + pub name: String, + pub opType: String, + pub input: Vec, + pub output: Vec, + pub shapes: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Edge { + pub source: String, + pub target: String, +} + +impl Edge { + pub fn new(s: String, t: String) -> Self { + Self { source: s, target: t } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Model { + pub name: String, + pub nodes: HashMap, + // pub edges: Vec, + pub parameters: HashMap, +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mind_ir.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mind_ir.rs new file mode 100644 index 000000000..aa53112f9 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mind_ir.rs @@ -0,0 +1,222 @@ +use std::{ + fs::File, + io::{BufReader, Read, Result}, +}; + +use ahash::{HashMap, HashMapExt, HashSet, HashSetExt}; +use crate::mind_ir::{ + AttributeProto, ModelProto, NodeProto, PrimitiveProto, TensorProto, + attribute_proto::AttributeType, +}; +use prost::Message; +use smartstring::alias::String; + +use crate::{ + model::{Model, Node}, + str_ext::StrExt, +}; + +/// TODO: Use `binary` translation to skip unnecessary buffer areas +pub fn parse_pb(path: &str) -> Result { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + let mut buffer = vec![]; + reader.read_to_end(&mut buffer)?; + + let model: ModelProto = Message::decode(&*buffer)?; + + Ok(model) +} + +#[allow(dead_code)] +fn extract_non_ops(primitives: Vec) -> HashSet { + let mut res = HashSet::new(); + + for primitive in &primitives { + if primitive.attribute.is_empty() { + res.insert(primitive.op_type.clone().into()); + } + } + + res +} + +#[inline] +fn extract_op(raw: String) -> String { + if raw.starts_with("REF::") { + raw.try_strip_prefix("REF::").split_once(":").unwrap().0.into() + } else { + raw + } +} + +/// TODO: `Extract functions` and optimize the logic +pub fn parse_mindir_model(path: &str) -> Option { + match parse_pb(path) { + Ok(model_proto) => { + if let Some(graph) = model_proto.graph { + let mut op_types: HashMap = HashMap::new(); + let mut node_names: HashSet = HashSet::new(); + let mut node_name_map: HashMap = HashMap::new(); + + let name = graph.name(); + let size = graph.node.len(); + + let prefix = ""; + + let mut nodes: HashMap = HashMap::with_capacity(size); + // let mut edges = vec![]; + let mut parameters = HashMap::new(); + + for primitive in &model_proto.primitives { + if !primitive.attribute.is_empty() { + op_types.insert( + format!("REF::{}", primitive.name).into(), + primitive.op_type.clone().into(), + ); + } + } + + for parameter in &graph.parameter { + let ps: String = parameter.name.try_strip_prefix(prefix).into(); + parameters.insert(ps.clone(), format_tensor(parameter)); + node_names.insert(ps); + } + + for input in &graph.input { + node_names.insert(input.name.try_strip_prefix(prefix).into()); + } + + for output in &graph.output { + node_names.insert(output.name.try_strip_prefix(prefix).into()); + } + + for node_proto in &graph.node { + let name: String = node_proto.name.try_strip_prefix(prefix).into(); + let op_type: String = node_proto.op_type.clone().into(); + + if !op_types.contains_key(&op_type) { + for i in &node_proto.input { + let is: String = i.clone().into(); + if node_names.contains(&is) { + node_name_map.insert(name.clone(), is); + } + } + continue; + } + + generate_node( + node_proto, + prefix, + &mut nodes, + &mut node_names, + &mut op_types, + &mut node_name_map, + ); + } + + let model = Model { name: name.into(), nodes, parameters }; + + let mut p = name.to_string(); + p.push(':'); + + let s = serde_json::to_string(&model).unwrap().replace(p.as_str(), ""); + let pm: Model = serde_json::from_str(s.as_str()).unwrap(); + + return Some(pm); + } + None + } + _ => None, + } +} + +fn generate_node( + node: &NodeProto, + prefix: &str, + nodes: &mut HashMap, + node_names: &mut HashSet, + op_types: &mut HashMap, + node_name_map: &mut HashMap, +) { + let node_name: String = node.name.clone().into(); + let node_name = node_name.try_strip_prefix(prefix); + let node_name_string: String = node_name.into(); + let op_type: String = node.op_type.clone().into(); + if !op_types.contains_key(&op_type) { + for input in &node.input { + let si: String = input.clone().into(); + if node_names.contains(&si) { + node_name_map.insert(node_name_string.clone(), si); + } + } + + return; + } + + let real_op = extract_op(op_type); + + let raw_attrs = &node.attribute; + + let mut input = vec![]; + let mut output = vec![]; + + let mut shapes = String::new(); + for attr in raw_attrs { + if attr.name == "shape" { + shapes = match attr.r#type() { + AttributeType::Tensors => format_tensors(&attr.tensors), + AttributeType::Tuple => format_tuple(&attr.values), + _ => String::new(), + } + } + } + + node_names.insert(node_name_string.clone()); + + for source in &node.input { + let mut ps: String = source.try_strip_prefix(prefix).into(); + if !node_names.contains(&ps) { + ps = node_name_map.get(&ps).unwrap_or(&"".into()).clone().into(); + } + if !ps.is_empty() { + input.push(ps) + } + } + + for sink in &node.output { + let ps: String = sink.try_strip_prefix(prefix).into(); + output.push(ps) + } + + nodes.insert(node_name_string.clone(), Node { + name: node_name_string, + opType: real_op, + input, + output, + shapes, + }); +} + +#[inline] +fn format_tensor(tensor: &TensorProto) -> String { + let dims = tensor.dims.iter().map(|i| i.to_string().into()).collect::>().join(", "); + + format!("{:?}({})", tensor.dtype(), dims).into() +} + +#[inline] +fn format_tensors(tensors: &Vec) -> String { + let mut s = String::new(); + + for tensor in tensors { + s.push_str(format_tensor(tensor).as_str()) + } + + s +} + +#[inline] +fn format_tuple(values: &Vec) -> String { + format_tensors(&values[0].tensors) +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mod.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mod.rs new file mode 100644 index 000000000..f0e29e195 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mod.rs @@ -0,0 +1 @@ +pub mod mind_ir; diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/quick.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/quick.rs new file mode 100644 index 000000000..12df1d5f2 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/quick.rs @@ -0,0 +1,234 @@ +use std::{ + borrow::Cow, + fs::File, + io::{BufReader, Read}, +}; +use std::collections::HashMap; +use ahash::{HashMap, HashMapExt, HashSet, HashSetExt}; +use quick_protobuf::{BytesReader, MessageRead}; +use serde::{Deserialize, Serialize}; + +use crate::{ + proto::mind_ir::{ + AttributeProto, ModelProto, NodeProto, TensorProto, mod_AttributeProto::AttributeType, + mod_TensorProto::TensorDataType, + } +}; + +#[derive(Debug, Serialize, Deserialize)] +struct Tensor { + pub dtype: TensorDataType, + pub dims: Vec, +} + +#[allow(non_snake_case)] +#[derive(Debug, Serialize, Deserialize)] +struct Node<'a> { + pub name: Cow<'a, str>, + pub opType: Cow<'a, str>, + pub input: Vec>, + pub output: Vec>, + shapes: Cow<'a, str>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Edge<'a> { + pub source: Cow<'a, str>, + pub target: Cow<'a, str>, +} + +impl<'a> Edge<'a> { + pub fn new(s: Cow<'a, str>, t: Cow<'a, str>) -> Self { + Self { source: s, target: t } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Model<'a> { + pub name: Cow<'a, str>, + pub nodes: HashMap, Node<'a>>, + // pub edges: Vec, + parameters: HashMap, Cow<'a, str>>, +} + +#[inline] +fn extract_op(raw: String) -> String { + let m = HashMap::new(); + m.insert(1,1); + if raw.starts_with("REF::") { + raw.try_strip_prefix("REF::").split_once(":").unwrap().0.into() + } else { + raw + } +} + +pub fn quick_parse_model<'a>(path: &str) -> Option { + let file = File::open(path).unwrap(); + let mut buf_reader = BufReader::new(file); + let mut buffer = vec![]; + buf_reader.read_to_end(&mut buffer).unwrap(); + let mut reader = BytesReader::from_bytes(&buffer); + let model_proto = ModelProto::from_reader(&mut reader, &buffer).expect("Cannot read FooBar"); + + if let Some(graph) = model_proto.graph { + let mut op_types: HashMap, Cow<'a, str>> = HashMap::new(); + let mut node_names: HashSet> = HashSet::new(); + let mut node_name_map: HashMap, Cow<'a, str>> = HashMap::new(); + + let name = graph.name.unwrap_or(Cow::from("")); + let size = graph.node.len(); + + let prefix = ""; + + let mut nodes: HashMap, Node> = HashMap::with_capacity(size); + // let mut edges = vec![]; + let mut parameters = HashMap::new(); + + for primitive in &model_proto.primitives { + if !primitive.attribute.is_empty() { + op_types.insert( + format!("REF::{}", primitive.name).into(), + primitive.op_type.clone().into(), + ); + } + } + + for parameter in graph.parameter { + let ps = parameter.name.try_strip_prefix(prefix); + parameters.insert(Cow::from(ps), format_tensor(¶meter)); + node_names.insert(Cow::from(ps)); + } + + for input in graph.input { + node_names.insert(input.name.try_strip_prefix(prefix).into()); + } + + for output in graph.output { + node_names.insert(output.name.try_strip_prefix(prefix).into()); + } + + for node_proto in graph.node { + let name = node_proto.name.try_strip_prefix(prefix); + let op_type: String = node_proto.op_type.clone().into(); + + if !op_types.contains_key(&node_proto.op_type) { + for i in &node_proto.input { + if node_names.contains(i) { + node_name_map.insert(Cow::from(name), i.clone()); + } + } + continue; + } + + generate_node( + &node_proto, + prefix, + &mut nodes, + &mut node_names, + &mut op_types, + &mut node_name_map, + ); + } + + let model = Model { name: name.to_string().into(), nodes, parameters }; + + let mut p = name.to_string(); + p.push(':'); + + let s = serde_json::to_string(&model).unwrap().replace(p.as_str(), ""); + let pm: Model = serde_json::from_str(s.as_str()).unwrap(); + + return Some(pm); + } + + None +} + +fn generate_node<'a>( + node: &'a NodeProto, + prefix: &str, + nodes: &mut HashMap, Node<'a>>, + node_names: &'a mut HashSet>, + op_types: &mut HashMap, Cow<'a, str>>, + node_name_map: &'a mut HashMap, Cow<'a, str>>, +) { + let node_name = node.name.try_strip_prefix(prefix); + + let op_type: String = node.op_type.clone().into(); + if !op_types.contains_key(&node.op_type) { + for input in &node.input { + let si: String = input.clone().into(); + if node_names.contains(input) { + node_name_map.insert(Cow::from(node_name), input.to_owned()); + } + } + + return; + } + + let real_op = extract_op(op_type); + + let raw_attrs = &node.attribute; + + let mut input = vec![]; + let mut output = vec![]; + + let mut shapes = Cow::from(""); + for attr in raw_attrs { + if attr.name == "shape" { + shapes = match attr.type_pb { + AttributeType::TENSORS => format_tensors(&attr.tensors), + AttributeType::TUPLE => format_tuple(&attr.values), + _ => Cow::from(""), + } + } + } + + node_names.insert(Cow::from(node_name)); + + for source in &node.input { + let mut ps = source.try_strip_prefix(prefix); + if !node_names.contains(&Cow::from(ps)) { + ps = node_name_map.get(ps).unwrap_or(&"".into()); + } + if !ps.is_empty() { + input.push(Cow::from(ps)) + } + } + + for sink in &node.output { + let ps = sink.try_strip_prefix(prefix); + output.push(Cow::from(ps)) + } + + nodes.insert(Cow::from(node_name), Node { + name: Cow::from(node_name), + opType: Cow::from(real_op), + input, + output, + shapes, + }); +} + +#[inline] +fn format_tensor<'a>(tensor: &TensorProto) -> Cow<'a, str> { + let dims = tensor.dims.iter().map(|i| i.to_string().into()).collect::>().join(", "); + + format!("{:?}({})", tensor.dtype, dims).into() +} + +#[inline] +fn format_tensors<'a>(tensors: &Vec) -> Cow<'a, str> { + let mut s = String::new(); + + for tensor in tensors { + s.push_str(&*format_tensor(tensor)) + } + + s.into() +} + +#[inline] +fn format_tuple<'a>(values: &Vec) -> Cow<'a, str> { + format_tensors(&values[0].tensors) +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/str_ext.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/str_ext.rs new file mode 100644 index 000000000..29d24a118 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/str_ext.rs @@ -0,0 +1,77 @@ +use std::borrow::Cow; + +use smartstring::alias::String as SmartString; + +pub trait StrExt { + fn try_strip_prefix(&self, prefix: &str) -> &str; +} + +impl StrExt for &str { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self.strip_prefix(prefix) { + Some(ps) => ps, + _ => self, + } + } +} + +impl StrExt for SmartString { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self.as_str().strip_prefix(prefix) { + Some(ps) => ps, + _ => self, + } + } +} + +impl StrExt for Option { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self { + Some(s) => s.try_strip_prefix(prefix), + _ => "", + } + } +} + +impl StrExt for String { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self.as_str().strip_prefix(prefix) { + Some(ps) => ps, + _ => self, + } + } +} + +impl StrExt for Option { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self { + Some(s) => s.try_strip_prefix(prefix), + _ => "", + } + } +} + +impl<'a> StrExt for Cow<'a, str> { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self.strip_prefix(prefix) { + Some(s) => s, + _ => self, + } + } +} + +impl<'a> StrExt for Option> { + #[inline] + fn try_strip_prefix(&self, prefix: &str) -> &str { + match self { + Some(s) => s.try_strip_prefix(prefix), + None => "", + } + } +} diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/tests/test_strip_prefix.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/tests/test_strip_prefix.rs new file mode 100644 index 000000000..7f100fc81 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/tests/test_strip_prefix.rs @@ -0,0 +1,82 @@ +use std::borrow::Cow; +use smartstring::alias::String as SmartString; +use parser::StrExt; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_strip_prefix() { + let s = "hello world"; + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s); + } + + #[test] + fn test_smartstring_strip_prefix() { + let s = SmartString::from("hello world"); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_str()); + } + + #[test] + fn test_option_smartstring_strip_prefix_some() { + let s = Some(SmartString::from("hello world")); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_ref().unwrap()); + } + + #[test] + fn test_option_smartstring_strip_prefix_none() { + let s: Option = None; + assert_eq!(s.try_strip_prefix("hello"), ""); + } + + #[test] + fn test_string_strip_prefix() { + let s = String::from("hello world"); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_str()); + } + + #[test] + fn test_option_string_strip_prefix_some() { + let s = Some(String::from("hello world")); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_ref().unwrap()); + } + + #[test] + fn test_option_string_strip_prefix_none() { + let s: Option = None; + assert_eq!(s.try_strip_prefix("hello"), ""); + } + + #[test] + fn test_cow_strip_prefix() { + let s = Cow::from("hello world"); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_ref()); + } + + #[test] + fn test_option_cow_strip_prefix_some() { + let s = Some(Cow::from("hello world")); + assert_eq!(s.try_strip_prefix("hello "), "world"); + assert_eq!(s.try_strip_prefix("hello"), " world"); + assert_eq!(s.try_strip_prefix("world"), s.as_ref().unwrap()); + } + + #[test] + fn test_option_cow_strip_prefix_none() { + let s: Option> = None; + assert_eq!(s.try_strip_prefix("hello"), ""); + } +} -- Gitee