From fba95f1a4d97780c55e3203fb08c807fac4edf45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=A9=BA=E7=99=BD?= <3440771474@qq.com>
Date: Tue, 29 Apr 2025 22:21:25 +0800
Subject: [PATCH] feat: support nested graphs

Signed-off-by: nightwalk
---
 .../ModelVis/rust/parser/src/pbtxt/onnx.rs | 293 ++++++++++++++++++
 1 file changed, 293 insertions(+)
 create mode 100644 plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs

diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs
new file mode 100644
index 0000000000..a6d3a64b13
--- /dev/null
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs
@@ -0,0 +1,293 @@
+mod onnx {
+    include!("./gen.rs");
+}
+
+use ahash::{HashMap, HashMapExt};
+use onnx::{
+    AttributeProto, GraphProto, ModelProto, NodeProto, SparseTensorProto, TensorProto,
+    attribute_proto::AttributeType,
+};
+use smartstring::alias::String;
+
+use super::parse_pbtxt;
+use crate::{
+    AttrValue,
+    AttrValue::{StringLike, TensorVal, TensorVals},
+    Model, Node, SmartStringExt, Subgraph,
+};
+
+pub fn parse_onnx_pbtxt(path: &str) -> Option<Model> {
+    match parse_pbtxt::<ModelProto>(path) {
+        Ok(model) => model.into(),
+        Err(_) => None,
+    }
+}
+
+impl From<ModelProto> for Option<Model> {
+    fn from(value: ModelProto) -> Self {
+        if let Some(graph) = value.graph.0 {
+            let mut ctx = GraphParsingContext::new();
+            ctx.preprocess(&graph);
+            ctx.process_io();
+            ctx.process_graph(&graph);
+
+            return Some(Model {
+                name: String::from(&graph.name),
+                nodes: ctx.nodes.into_iter().map(|(k, v)| (String::from(k), v)).collect(),
+                edges: ctx
+                    .edges
+                    .into_iter()
+                    .map(|(s, t)| (String::from(s), String::from(t)))
+                    .collect(),
+                parameters: HashMap::new(),
+                subgraphes: ctx
+                    .subgraphes
+                    .into_iter()
+                    .map(|k| (String::from(k), Subgraph { name: String::from(k) }))
+                    .collect(),
+                nesting_map: ctx
+                    .nesting_map
+                    .into_iter()
+                    .map(|(c, p)| (String::from(c), String::from(p)))
+                    .collect(),
+            });
+        }
+
+        None
+    }
+}
+
+/// Accumulates flattened nodes, edges and the child-to-parent `nesting_map`
+/// while walking a (possibly nested) `GraphProto`.
+struct GraphParsingContext<'a> {
+    raw_nodes: Vec<&'a NodeProto>,
+    nodes: HashMap<&'a str, Node>,
+    edges: Vec<(&'a str, &'a str)>,
+    subgraphes: Vec<&'a str>,
+    input_count_map: HashMap<&'a str, i32>,
+    output_count_map: HashMap<&'a str, i32>,
+    nesting_map: HashMap<&'a str, &'a str>,
+}
+
+impl<'a> GraphParsingContext<'a> {
+    fn new() -> Self {
+        Self {
+            raw_nodes: Vec::new(),
+            nodes: HashMap::new(),
+            edges: Vec::new(),
+            subgraphes: Vec::new(),
+            input_count_map: HashMap::new(),
+            output_count_map: HashMap::new(),
+            nesting_map: HashMap::new(),
+        }
+    }
+
+    /// Collects every node reachable from `g`, descending into graphs held by
+    /// GRAPH attributes, and records which nodes own a subgraph.
+    fn preprocess(&mut self, g: &'a GraphProto) {
+        use AttributeType::*;
+
+        for node in &g.node {
+            self.raw_nodes.push(node);
+
+            for attr in &node.attribute {
+                match attr.type_.enum_value_or_default() {
+                    GRAPH => {
+                        if let Some(graph) = attr.g.as_ref() {
+                            self.subgraphes.push(&node.name);
+                            self.preprocess(graph);
+                        }
+                    }
+                    GRAPHS => {
+                        // Lists of nested graphs are not flattened yet; see the
+                        // GRAPHS arm in `process_attr`.
+                    }
+                    _ => {}
+                }
+            }
+        }
+    }
+
+    /// Counts how many nodes consume/produce each tensor name; the counts let
+    /// `is_constant_node` spot single-use constants that can be folded away.
+    fn process_io(&mut self) {
+        for raw_node in &self.raw_nodes {
+            for i in &raw_node.input {
+                self.input_count_map.entry(i).and_modify(|x| *x += 1).or_insert(1);
+            }
+
+            for o in &raw_node.output {
+                self.output_count_map.entry(o).and_modify(|x| *x += 1).or_insert(1);
+            }
+        }
+    }
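+
+    /// Entry point for node conversion: walks the top-level graph; nested
+    /// graphs are reached later through `process_attr`.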
+    fn process_graph(&mut self, graph: &'a GraphProto) {
+        for node in &graph.node {
+            self.process_node(node)
+        }
+    }
+
+    fn process_node(&mut self, node: &'a NodeProto) {
+        // Single-use constants are folded away instead of becoming nodes.
+        if self.handle_constant(node) {
+            return;
+        }
+
+        let mut dynamic = false;
+        let mut is_subgraph = false;
+        let mut attributes = HashMap::new();
+        for attr in &node.attribute {
+            // `process_attr` returns `None` when the attribute embeds a
+            // subgraph; such nodes are represented by their children instead.
+            if let Some(value) = self.process_attr(attr, &node.name) {
+                attributes.insert(String::from(&attr.name), value);
+            } else {
+                is_subgraph = true
+            }
+
+            if attr.is_dyn_shape() {
+                dynamic = true
+            }
+        }
+
+        if !is_subgraph {
+            let name: &'a str = &node.name;
+            #[allow(non_snake_case)]
+            let opType = String::from(&node.op_type);
+            let input = String::from_slice(&node.input);
+            let output = String::from_slice(&node.output);
+
+            self.nodes.insert(
+                name,
+                Node { name: String::from(name), opType, input, output, attributes, dynamic },
+            );
+
+            self.collect_edges(node);
+        }
+    }
+
+    /// Connects `node` to every node that consumes one of its outputs.
+    fn collect_edges(&mut self, node: &'a NodeProto) {
+        let name = &node.name;
+        for output in &node.output {
+            for rn in &self.raw_nodes {
+                if rn.input.contains(output) && name != &rn.name {
+                    self.edges.push((name, &rn.name));
+                }
+            }
+        }
+    }
+
+    fn process_attr(&mut self, attr: &'a AttributeProto, name: &'a str) -> Option<AttrValue> {
+        use AttributeType::*;
+
+        let discriminator: AttributeType = attr.type_.enum_value().unwrap_or_default();
+
+        let value = match discriminator {
+            UNDEFINED => StringLike(String::from("")),
+            FLOAT => StringLike(String::from_f32(attr.f)),
+            INT => StringLike(String::from_i64(attr.i)),
+            STRING => StringLike(String::from_vecu8(&attr.s)),
+            TENSOR => TensorVal(TensorFormatter::fmt(attr.t.as_ref()?)),
+            GRAPH => {
+                // Flatten the nested graph into the parent model and remember
+                // which node each child hangs off via `nesting_map`.
+                let graph = attr.g.as_ref()?;
+                for subnode in &graph.node {
+                    self.nesting_map.insert(&subnode.name, name);
+                    self.process_node(subnode);
+                }
+                return None;
+            }
+            SPARSE_TENSOR =>
+                StringLike(String::from_i64s(&attr.sparse_tensor.as_ref()?.dims)),
+            FLOATS => StringLike(String::from_f32s(&attr.floats)),
+            INTS => StringLike(String::from_i64s(&attr.ints)),
+            STRINGS => StringLike(String::from_2dvecu8(&attr.strings)),
+            TENSORS => format_tensors(&attr.tensors),
+            // Lists of nested graphs are not supported yet.
+            GRAPHS => unimplemented!(),
+            SPARSE_TENSORS => format_tensors(&attr.sparse_tensors),
+        };
+
+        Some(value)
+    }
+
+    /// Returns `true` when `node` is a `Constant` whose value should be folded
+    /// into its consumer rather than shown as a standalone node.
+    fn handle_constant(&mut self, node: &NodeProto) -> bool {
+        let constant = self.is_constant_node(node);
+
+        let attribute = if constant { node.attribute.first() } else { None };
+
+        if let Some(attribute) = attribute {
+            if attribute.is_tensor() {
+                return true;
+            }
+
+            if attribute.is_sparse_tensor() {
+                return true;
+            }
+        }
+
+        false
+    }
+
+    fn is_constant_node(&self, node: &NodeProto) -> bool {
+        node.op_type == "Constant"
+            && node.attribute.len() == 1
+            && node.input.is_empty()
+            && node.output.len() == 1
+            && self.input_count_map.get(node.output[0].as_str()) == Some(&1)
+            && self.output_count_map.get(node.output[0].as_str()) == Some(&1)
+    }
+}
+
+impl AttributeProto {
+    fn is_tensor(&self) -> bool {
+        self.name == "value"
+            && self.type_.enum_value() == Ok(AttributeType::TENSOR)
+            && self.t.is_some()
+    }
+
+    fn is_sparse_tensor(&self) -> bool {
+        self.name == "sparse_value"
+            && self.type_.enum_value() == Ok(AttributeType::SPARSE_TENSOR)
+            && self.sparse_tensor.is_some()
+    }
+
+    fn is_dyn_shape(&self) -> bool {
+        self.i == 1
+            && (self.name == "_is_unknown_shape" || self.name == "_force_unknown_shape")
+    }
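+
+    /// `true` when the attribute's type marks it as a nested graph
+    /// (GRAPH or GRAPHS).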
+    fn is_subgraph(&self) -> bool {
+        matches!(
+            self.type_.enum_value_or_default(),
+            AttributeType::GRAPH | AttributeType::GRAPHS
+        )
+    }
+}
+
+impl TensorFormatter for TensorProto {
+    fn fmt(&self) -> String {
+        let dims = String::from_i64s(&self.dims);
+
+        // `enum_value_or_default` avoids panicking on unrecognized dtypes.
+        format!("{:?}({})", self.dtype.enum_value_or_default(), dims).into()
+    }
+}
+
+impl TensorFormatter for SparseTensorProto {
+    fn fmt(&self) -> String {
+        String::from_i64s(&self.dims)
+    }
+}
+
+trait TensorFormatter {
+    fn fmt(&self) -> String;
+}
+
+#[inline]
+fn format_tensors<T>(tensors: &[T]) -> AttrValue
+where
+    T: TensorFormatter,
+{
+    // Guard against an empty list: indexing `tensors[0]` would panic.
+    match tensors {
+        [] => StringLike(String::from("")),
+        [only] => StringLike(only.fmt()),
+        _ => TensorVals(tensors.iter().map(|t| t.fmt()).collect()),
+    }
+}
-- Gitee