diff --git a/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock b/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock
index 1b241206eadb6548fde9858ec464162d42575450..0b55074943a4f23f6c49e9017fba6702197eb7d3 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock
+++ b/plugins/mindstudio-insight-plugins/ModelVis/Cargo.lock
@@ -4504,6 +4504,7 @@ dependencies = [
  "tauri-build",
  "tauri-plugin-dialog",
  "tauri-plugin-shell",
+ "tokio",
 ]
 
 [[package]]
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml
index 2d91c8ba86c8f90d788c05b42fa71e76a91d99a7..c81a357748ef772f41228f5840c1d1667dba3834 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml
+++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/Cargo.toml
@@ -22,3 +22,4 @@ parser = { path = "../../rust/parser" }
 layout = { path = "../../rust/layout" }
 fsg = { path = "../../rust/fsg" }
 anyhow = { workspace = true }
+tokio = "1.45.0"
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/resources/bin.ts b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/resources/bin.ts
index 77a16282164eb0ca2b193373a74238b6e90dda7f..1f69d3654593be2e73315c73855351c15425ddc8 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/resources/bin.ts
+++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/resources/bin.ts
@@ -7,67 +7,92 @@ const RS_PATH = join(CACHE_FOLDER, "rs.json")
 
 const RsData = await Bun.file(RS_PATH).json()
 
-const {nodes: rawNodes, edges: rawEdges, subgraphs, nesting_map} = RsData
+// Read the command-line arguments
+const args = process.argv.slice(2);
+const parseName = args[0];
 
-const nodes = []
-const edges = []
+main();
 
-for (const name of Object.keys(subgraphs)) {
-    nodes.push({
-        v: name,
-        opType: "subgraph"
-    })
+function main() {
+    if (RsData.name === parseName) {
+        parse(RsData);
+        return;
+    }
+    recursive_parse_single_graph(RsData);
 }
 
-for (const node of Object.values(rawNodes)) {
-    const n = {
-        v: node.name,
-        width: 100,
-        height: 30,
-        opType: node.opType,
-        parent: nesting_map[node.name] ?? null,
+async function recursive_parse_single_graph(model) {
+    if (model.subgraphes === undefined) {
+        return false;
+    }
+    if (model.subgraphes[parseName] !== undefined) {
+        await parse(model.subgraphes[parseName]);
+        return true;
+    }
+    for (let subgraph of Object.values(model.subgraphes)) {
+        if (await recursive_parse_single_graph(subgraph)) {
+            return true;
+        }
     }
+    return false;
+}
 
-    if (node.tensors) n.tensors = node.tensors
+async function parse(model) {
+    const {nodes: rawNodes, edges: rawEdges} = model
 
-    nodes.push(n)
-}
+    const nodes = []
+    const edges = []
 
-for (const [source, target] of rawEdges) {
-    if (rawNodes[source] && rawNodes[target]) {
-        edges.push({
-            v: source,
-            w: target,
-        })
+    for (const node of Object.values(rawNodes)) {
+        const n = {
+            v: node.name,
+            width: 100,
+            height: 30,
+            opType: node.opType,
+            parent: null, // the node's parent graph; nested multi-graph data is no longer passed here, so this is always null
+        }
+
+        if (node.tensors) n.tensors = node.tensors
+
+        nodes.push(n)
+    }
+
+    for (const [source, target] of rawEdges) {
+        if (rawNodes[source] && rawNodes[target]) {
+            edges.push({
+                v: source,
+                w: target,
+            })
+        }
     }
-}
 
-layoutAlg(nodes, edges, {nodesep: 20.0, edgesep: 20.0, ranksep: 20.0}, {})
+    layoutAlg(nodes, edges, {nodesep: 20.0, edgesep: 20.0, ranksep: 20.0}, {})
 
-const final = {
-    nodes: [],
-    edges: []
-}
+    const final = {
+        nodes: [],
+        edges: []
+    }
 
-for (const node of nodes) {
-    final.nodes.push({
-        id: node.v,
-        opType: node.opType,
-        x: node.x - node.width / 2,
-        y: node.y - node.height / 2,
-        width: node.width,
-        height: node.height
-    })
-}
+    for (const node of nodes) {
+        final.nodes.push({
+            id: node.v,
+            opType: node.opType,
+            x: node.x - node.width / 2,
+            y: node.y - node.height / 2,
+            width: node.width,
+            height: node.height
+        })
+    }
 
-for (const edge of edges) {
-    final.edges.push({
-        source: edge.v,
-        target: edge.w,
-        points: edge.points
-    })
-}
+    for (const edge of edges) {
+        final.edges.push({
+            source: edge.v,
+            target: edge.w,
+            points: edge.points
+        })
+    }
 
-const TS_PATH = join(CACHE_FOLDER, "ts.json")
+    const TS_PATH = join(CACHE_FOLDER, "ts.json")
 
-await Bun.write(TS_PATH, JSON.stringify(final))
+    await Bun.write(TS_PATH, JSON.stringify(final))
+}
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs
index 21f88f8885205182241dafb4b87bb300499b8b21..e2dea5c9d3335981c8fee1188fc55b16a3678148 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/app/src-tauri/src/commands.rs
@@ -1,17 +1,18 @@
 use std::{env::home_dir, fs::{create_dir, File}, io, io::{BufReader, BufWriter}, path::Path, time::Instant};
-
+use std::path::PathBuf;
 use ahash::{HashMap, HashMapExt};
 use anyhow::Result;
 use layout::{layout, Graph, GraphEdge, GraphNode, Key, KeyCodecExt};
-use parser::{parse_bin, Model};
+use parser::{parse_bin, Model, StdString};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use smartstring::alias::String;
 use fsg::{result::JSONResult, fsgs_bin};
-use tauri::{path::BaseDirectory, AppHandle, Manager, Result as InvokeResult};
+use tauri::{path::BaseDirectory, AppHandle, Emitter, Manager, Result as InvokeResult};
+use tauri::async_runtime::spawn_blocking;
 use tauri_plugin_shell::ShellExt;
 
 #[allow(non_snake_case)]
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 struct RenderNode {
     id: String,
     x: i32,
@@ -100,7 +101,7 @@ fn calc_edge_bounding(points: &[Point]) -> BoundsArray {
     )
 }
 
-#[derive(Serialize)]
+#[derive(Clone, Serialize)]
 struct RenderEdge {
     x: i32,
     y: i32,
@@ -178,15 +179,25 @@ fn layout_model(model: &Model) -> (Vec, Vec) {
     (nodes, edges)
 }
 
-#[derive(Serialize)]
+const PARSE_GRAPH_EVENT: &str = "parse_graph_success";
+
+#[derive(Clone, Serialize)]
 pub struct LayoutRet {
     model: Model,
     nodes: Vec<RenderNode>,
     edges: Vec<RenderEdge>,
+    subgraphes: HashMap,
+}
+
+#[derive(Serialize)]
+pub struct GraphLayer {
+    name: String,
+    paths: Vec<String>,
+    children: Vec<GraphLayer>,
 }
 
 #[tauri::command]
-pub async fn layout_bin(handle: AppHandle, path: String) -> InvokeResult {
+pub fn layout_bin(handle: AppHandle, path: String) -> InvokeResult {
     let script_path = handle.path().resolve("resources/bin.ts", BaseDirectory::Resource)?;
 
     let s1 = Instant::now();
@@ -202,19 +213,90 @@ pub async fn layout_bin(handle: AppHandle, path: String) -> InvokeResult
+fn layer_recursive(model: &Model, parent_list: &mut Vec<String>) -> GraphLayer {
+    let name = model.name.clone();
+    parent_list.push(name.clone());
+    let layer = if model.subgraphes.is_empty() {
+        GraphLayer {
+            name,
+            paths: parent_list.clone(),
+            children: vec![],
+        }
+    } else {
+        GraphLayer {
+            name: model.name.clone(),
+            paths: parent_list.clone(),
+            children: model.subgraphes.values()
+                .map(|graph| layer_recursive(graph, parent_list))
+                .collect(),
+        }
+    };
+    parent_list.pop();
+    layer
+}
+fn layout_recursive(
+    handle: &AppHandle,
+    tmp_folder: &PathBuf,
+    model: Model,
+    ts_path: &str,
+) {
+    // Process the current model itself first
+    let simple_model = Model {
+        name: model.name,
+        nodes: model.nodes,
+        edges: model.edges,
+        parameters: model.parameters,
+        subgraphes: HashMap::new(),
+    };
+
+    layout_single_graph(handle, tmp_folder, simple_model, ts_path);
+
+    // Then recursively process each subgraph
+    for (_, subgraph_model) in model.subgraphes {
+        layout_recursive(handle, tmp_folder, subgraph_model, ts_path);
+    }
+}
+
+
+fn layout_single_graph(handle: &AppHandle, tmp_folder: &PathBuf, model: Model, ts_path: &str) {
     let s3 = Instant::now();
-    let bun_command = handle.shell().sidecar("bun").unwrap().args([ts_path]);
-    let _ = bun_command.output().await;
-    println!("layout costs {:?}", s3.elapsed());
+
+    // Use the sidecar API directly
+    let bun_command = handle.shell().sidecar("bun").unwrap().args([ts_path, &model.name]);
+
+    // Block and wait on a local tokio Runtime (safe here)
+    let rt = tokio::runtime::Runtime::new().unwrap();
+    let output = rt.block_on(bun_command.output())
+        .expect("Failed to run bun");
+    if !output.status.success() {
+        panic!("Bun failed: {}", StdString::from_utf8_lossy(&output.stderr));
+    }
+    println!("layout costs {:?} {:?}", &model.name, s3.elapsed());
 
     let s4 = Instant::now();
     let g = tmp_folder.join("ts.json");
     let js_ret: JSLayoutRet = read_from(&g);
-    println!("decode costs {:?}", s4.elapsed());
+    println!("> e:{}, v:{}", js_ret.edges.len(), js_ret.nodes.len());
+    println!("decode costs {:?} {:?}", &model.name, s4.elapsed());
 
-    Ok(merge_ret(js_ret, model))
+    let result = merge_ret(js_ret, model);
+    // Emit the single-graph parse result to the frontend as an event
+    handle.emit(PARSE_GRAPH_EVENT, result).unwrap();
 }
 
 fn write_to_json>(data: &T, path: P) {
@@ -295,6 +377,7 @@ fn merge_ret(jsr: JSLayoutRet, model: Model) -> LayoutRet {
         nodes: jsr.nodes.into_iter().collect(),
         edges: jsr.edges.into_iter().collect(),
         model,
+        subgraphes: HashMap::new(),
     }
 }
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs
index 6170f1777a4cda5a26b906ec311616adcb85f256..79eb8f6d41323eb6679467881ceeb4a3f1d0d88a 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/model.rs
@@ -19,14 +19,13 @@ use smartstring::alias::String;
 
 use self::AttrValue::*;
 
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Clone)]
 pub struct Model {
     pub name: String,
     pub nodes: HashMap<String, Node>,
     pub edges: Vec<(String, String)>,
     pub parameters: HashMap,
-    pub subgraphs: HashMap<String, Subgraph>,
-    pub nesting_map: HashMap<String, String>,
+    pub subgraphes: HashMap<String, Model>,
 }
 
 #[derive(Debug, Serialize)]
@@ -45,7 +44,7 @@ pub struct Tensor {
 }
 
 #[allow(non_snake_case)]
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Clone)]
 pub struct Node {
     pub name: String,
     pub opType: String,
@@ -60,7 +59,7 @@ pub struct Node {
     pub dynamic: bool,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum AttrValue {
     StringLike(String),
     StringLikeArray(Vec<String>),
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs
index 58dd52bd85f6dacc812f8c97d564692ccc494482..b096d78c68fcf2934d18f7219697e37406f92d5d 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/pbtxt/onnx.rs
@@ -28,12 +28,7 @@ use onnx::{
 use smartstring::alias::String;
 
 use super::parse_pbtxt;
-use crate::{
-    AttrValue,
-    AttrValue::{StringLike, TensorVal, TensorVals},
-    Model, Node, SmartStringExt, Subgraph, Tensor,
-};
-
+use crate::{AttrValue, AttrValue::{StringLike, TensorVal, TensorVals}, Model, Node, SmartStringExt, Tensor};
 pub fn parse_onnx_pbtxt(path: &str) -> Result<Model> {
     parse_pbtxt::<ModelProto>(path).map(|m| m.into())
 }
@@ -45,31 +40,27 @@ impl From<ModelProto> for Model {
     fn from(value: ModelProto) -> Self {
         let Some(graph) = value.graph.0 else { unreachable!() };
 
-        let mut ctx = GraphParsingContext::new();
-        ctx.preprocess(&graph);
-        ctx.process_io();
-        ctx.process_graph(&graph);
-        ctx.attach_tensor();
-
-        return Model {
-            name: String::from(&graph.name),
-            nodes: ctx.nodes.into_iter().map(|(k, v)| (String::from(k), v)).collect(),
-            edges: ctx.edges.into_iter().map(|(s, t)| (String::from(s), String::from(t))).collect(),
-            parameters: HashMap::new(),
-            subgraphs: ctx
-                .subgraphs
-                .into_iter()
-                .map(|k| (String::from(k), Subgraph { name: String::from(k) }))
-                .collect(),
-            nesting_map: ctx
-                .nesting_map
-                .into_iter()
-                .map(|(c, p)| (String::from(c), String::from(p)))
-                .collect(),
-        };
+        let mut ctx = RecursiveGraphParsingContext::new();
+        ctx.process_all(&graph);
+
+        transform_model(ctx)
     }
 }
 
+fn transform_model(ctx: RecursiveGraphParsingContext) -> Model {
+    return Model {
+        name: String::from(ctx.name),
+        nodes: ctx.nodes.into_iter().map(|(k, v)| (String::from(k), v)).collect(),
+        edges: ctx.edges.into_iter().map(|(s, t)| (String::from(s), String::from(t))).collect(),
+        parameters: HashMap::new(),
+        subgraphes: ctx
+            .subgraphes
+            .into_iter()
+            .map(|k| (String::from(k.name), transform_model(*k)))
+            .collect(),
+    };
+}
+
 /// # Design Rationale: GraphParsingContext and Memory Management
 ///
 /// The `GraphParsingContext` struct uses references (`&'a T`) for all its
@@ -87,48 +78,45 @@ impl From<ModelProto> for Model {
 /// - **Lifetime Alignment**:
 ///   All parsed elements share the same lifetime (`'a`) tied to the input `GraphProto`,
 ///   ensuring safe access without runtime overhead.
-struct GraphParsingContext<'a> {
+struct RecursiveGraphParsingContext<'a> {
+    name: &'a str,
     raw_nodes: Vec<&'a NodeProto>,
     nodes: HashMap<&'a str, Node>,
     edges: Vec<(&'a str, &'a str)>,
     tensors: HashMap<&'a str, Tensor>,
-    subgraphs: Vec<&'a str>,
+    subgraphes: Vec<Box<RecursiveGraphParsingContext<'a>>>,
     input_count_map: HashMap<&'a str, i32>,
     output_count_map: HashMap<&'a str, i32>,
-    nesting_map: HashMap<&'a str, &'a str>,
 }
 
-impl<'a> GraphParsingContext<'a> {
+impl<'a> RecursiveGraphParsingContext<'a> {
     fn new() -> Self {
         Self {
+            name: "",
            raw_nodes: vec![],
             nodes: HashMap::new(),
             edges: Vec::new(),
             tensors: HashMap::new(),
-            subgraphs: vec![],
+            subgraphes: vec![],
             input_count_map: HashMap::new(),
             output_count_map: HashMap::new(),
-            nesting_map: HashMap::new(),
         }
     }
 
+    pub fn process_all(&mut self, graph: &'a GraphProto) {
+        self.name = &graph.name;
+        self.preprocess(&graph);
+        self.process_io();
+        self.process_graph(&graph);
+        self.attach_tensor();
+    }
+
+    /// Flatten all non-subgraph nodes into raw_nodes
     fn preprocess(&mut self, g: &'a GraphProto) {
         use AttributeType::*;
 
         for node in &g.node {
             self.raw_nodes.push(node);
-
-            for attr in &node.attribute {
-                match attr.type_.enum_value_or_default() {
-                    GRAPH =>
-                        if let Some(graph) = attr.g.as_ref() {
-                            self.subgraphs.push(&node.name);
-                            self.preprocess(graph)
-                        },
-                    GRAPHS => attr.graphs.iter().for_each(|g| self.preprocess(g)),
-                    _ => {}
-                }
-            }
         }
     }
 
@@ -159,7 +147,7 @@ impl<'a> GraphParsingContext<'a> {
         let mut is_subgraph = false;
         let mut attributes = HashMap::new();
         for attr in &node.attribute {
-            if let Some(value) = self.process_attr(attr, &node.name) {
+            if let Some(value) = self.process_attr(attr) {
                 attributes.insert(String::from(&attr.name), value);
             } else {
                 is_subgraph = true
@@ -205,7 +193,7 @@ impl<'a> GraphParsingContext<'a> {
         }
     }
 
-    fn process_attr(&mut self, attr: &'a AttributeProto, name: &'a str) -> Option<AttrValue> {
+    fn process_attr(&mut self, attr: &'a AttributeProto) -> Option<AttrValue> {
         use AttributeType::*;
 
         let discriminator: AttributeType = attr.type_.enum_value().unwrap_or_default();
@@ -219,10 +207,9 @@ impl<'a> GraphParsingContext<'a> {
             TENSOR => value = TensorVal(TensorFormatter::fmt(attr.t.as_ref()?)),
             GRAPH => {
                 let graph = attr.g.as_ref()?;
-                for subnode in &graph.node {
-                    self.nesting_map.insert(&subnode.name, name);
-                    self.process_node(subnode);
-                }
+                let mut ctx = RecursiveGraphParsingContext::new();
+                ctx.process_all(graph);
+                self.subgraphes.push(Box::from(ctx));
                 return None;
             }
             SPARSE_TENSOR =>
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/geir.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/geir.rs
index 09ff295de2ab1d59dc6d4705136742ec3a77c973..c49f09bcb8158795f6448dd42adccdd4fec93f1a 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/geir.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/geir.rs
@@ -62,8 +62,7 @@ impl From for Model {
             nodes,
             edges: vec![],
             parameters,
-            subgraphs: HashMap::new(),
-            nesting_map: HashMap::new(),
+            subgraphes: HashMap::new(),
         }
     }
 }
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mindir.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mindir.rs
index 1ec0ce3a5b6f9c0db34071def3421d78c49b1b29..0cccb321cbe495f9ab9181b3f533242703225d26 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mindir.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/mindir.rs
@@ -50,8 +50,7 @@ impl From for Model {
             nodes: ctx.nodes,
             edges: ctx.edges,
             parameters: ctx.parameters,
-            subgraphs: HashMap::new(),
-            nesting_map: HashMap::new(),
+            subgraphes: HashMap::new(),
         }
     }
 }
diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/onnx.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/onnx.rs
index 15995f734ae5ddc9e8af2063cc192670ff168a70..1d3a017e21263815bde716e7e8e3e72d5ed06cd5 100644
--- a/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/onnx.rs
+++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/parser/src/processors/onnx.rs
@@ -22,13 +22,12 @@ use ahash::{HashMap, HashMapExt};
 use anyhow::Result;
 use onnx::{
     AttributeProto, GraphProto, ModelProto, NodeProto, SparseTensorProto, TensorProto,
-    attribute_proto::{AttributeType, AttributeType::*},
+    attribute_proto::{AttributeType::*},
 };
 use smartstring::alias::String;
 
 use super::{TensorFormatter, format_tensors, parse_pb};
 use crate::{AttrValue, AttrValue::*, Model, Node, SmartStringExt};
-
 pub fn parse_onnx(path: &str) -> Result<Model> {
     parse_pb::<ModelProto>(path).map(|m| m.into())
 }
@@ -134,8 +133,7 @@ impl From<ModelProto> for Model {
             nodes,
             edges,
             parameters,
-            subgraphs: HashMap::new(),
-            nesting_map: HashMap::new(),
+            subgraphes: HashMap::new(),
         }
     }
 }
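
Note on consuming the new event flow: with this change layout_bin no longer returns one combined LayoutRet; layout_single_graph lays out each graph and subgraph separately and pushes every result to the frontend through the "parse_graph_success" event. The sketch below shows one way a frontend could subscribe to that stream. It assumes the webview uses the Tauri v2 JavaScript API (@tauri-apps/api); the payload type only mirrors the few LayoutRet/RenderNode fields visible in this diff, and the model path is a placeholder.

// Sketch only: subscribe to per-graph layout results emitted by layout_single_graph.
import { listen } from "@tauri-apps/api/event";
import { invoke } from "@tauri-apps/api/core";

// Partial payload shape; the real LayoutRet carries more fields than listed here.
type GraphLayoutPayload = {
  nodes: Array<{ id: string; x: number; y: number }>;
  edges: unknown[];
};

// One event arrives per (sub)graph as soon as its layout finishes.
const unlisten = await listen<GraphLayoutPayload>("parse_graph_success", (event) => {
  console.log("graph ready:", event.payload.nodes.length, "nodes");
});

// Kick off parsing and layout; results stream in via the listener above.
// Call unlisten() when the view is torn down.
await invoke("layout_bin", { path: "/path/to/model.onnx" });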