diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/gspan.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/gspan.rs index 5b55e20e132d38c95b8a4d15fb768d4feb7636a9..c7fff3801de52cd8559d0cd7a73eecbe7d9e40a3 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/gspan.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/gspan.rs @@ -17,11 +17,11 @@ use std::fs::File; use std::io::{BufWriter, Write}; use std::usize; -use super::result::OutType; +use super::result::{OutSource, OutType}; pub struct GSpan { trans: Vec, // 图列表 - min_sup: usize, // Min support, 相同结果在不同图中出现的最小次数 + min_sup: usize, // 相同结果在不同图中出现的最小次数 inner_min_sup: usize, // 相同结构在同一图中出现的最小次数 max_pat_min: usize, // Minimum number of patterns(vertices) to be output max_pat_max: usize, // Maximum number of patterns(vertices) to be output @@ -36,10 +36,7 @@ impl GSpan { max_pat_min: usize, max_pat_max: usize, directed: bool, - out_type: OutType, ) -> GSpan { - let singleton = MaxDFSCodeGraphResult::get_instance(); - singleton.set_config(min_sup, inner_min_sup, max_pat_min, max_pat_max, out_type); GSpan { trans: graphs, min_sup, @@ -50,54 +47,22 @@ impl GSpan { } } - pub fn new_with_out_path( - graphs: Vec, - min_sup: usize, - inner_min_sup: usize, - max_pat_min: usize, - max_pat_max: usize, - directed: bool, - out_path: &str, - out_type: OutType, - ) -> GSpan { - let singleton = MaxDFSCodeGraphResult::get_instance(); - singleton.set_config(min_sup, inner_min_sup, max_pat_min, max_pat_max, out_type); - singleton.set_stream(BufWriter::new(File::create(out_path).unwrap())); - GSpan { - trans: graphs, - min_sup, - inner_min_sup, - max_pat_min, - max_pat_max, - directed, - } - } - - #[allow(dead_code)] - pub fn new_with_stream( - graphs: Vec, - min_sup: usize, - inner_min_sup: usize, - max_pat_min: usize, - max_pat_max: usize, - directed: bool, - output: W, - out_type: OutType, - ) -> GSpan { - let singleton = MaxDFSCodeGraphResult::get_instance(); - singleton.set_config(min_sup, inner_min_sup, max_pat_min, max_pat_max, out_type); - singleton.set_stream(output); - GSpan { - trans: graphs, - min_sup, - inner_min_sup, - max_pat_min, - max_pat_max, - directed, + pub fn run(&self, + out_type: OutType, // 输出类型 + out_source: Option>, // 输出源 + mut process: Option, // 过程数据输出位置 + ) -> (usize, MaxDFSCodeGraphResult) { + // 0. Prepare the Result + let mut result = MaxDFSCodeGraphResult::default(); + result.set_config(self.min_sup, self.inner_min_sup, self.max_pat_min, self.max_pat_max, out_type); + if let Some(out_source) = out_source { + match out_source { + OutSource::Path(path) => result.set_stream(BufWriter::new(File::create(path).unwrap())), + OutSource::Stream(stream) => result.set_stream(stream), + OutSource::Channel(sender) => result.set_channel(true, Some(sender)), + } } - } - pub fn run(&self, process: &mut Option) -> usize { // 1. Find single node frequent subgraph, if requested let mut single_vertex_graph_map: BTreeMap< usize, @@ -111,7 +76,7 @@ impl GSpan { let mut next_gid: usize = 0; self.print_frequent_single_vertex(&mut single_vertex_graph_map, &mut single_vertex_label_frequent_map, - &mut next_gid, process); + &mut next_gid, &mut process); // 3. Subgraphs > Vertices // root: [from_label][e_label][to_label] -> Projected @@ -144,12 +109,12 @@ impl GSpan { e_label_key.to_string(), to_label_key.to_string(), ); - self.sub_mining(to_label_value, &mut dfs_code, &mut next_gid, process); - dfs_code.pop_with_set_result(to_label_value); + self.sub_mining(to_label_value, &mut dfs_code, &mut next_gid, &mut process, &mut result); + dfs_code.pop_with_set_result(to_label_value, &mut result); } } } - next_gid + (next_gid, result) } fn find_frequent_single_vertex(&self, single_vertex_graph_map: &mut BTreeMap< @@ -228,7 +193,7 @@ impl GSpan { } fn sub_mining(&self, projected: &Projected, dfs_code: &mut DFSCode, - next_gid: &mut usize, process: &mut Option, + next_gid: &mut usize, process: &mut Option, result: &mut MaxDFSCodeGraphResult, ) { if self.should_stop_mining(projected, dfs_code, next_gid, process) { return; @@ -251,8 +216,8 @@ impl GSpan { for (e_label_key, e_label_value) in to_value.iter() { dfs_code.push(max_to_code, *to_key, Vertex::NIL_V_LABEL.to_string(), e_label_key.to_string(), Vertex::NIL_V_LABEL.to_string()); - self.sub_mining(e_label_value, dfs_code, next_gid, process); - dfs_code.pop_with_set_result(e_label_value); + self.sub_mining(e_label_value, dfs_code, next_gid, process, result); + dfs_code.pop_with_set_result(e_label_value, result); } } // .. forward @@ -261,19 +226,18 @@ impl GSpan { for (to_label_key, to_label_value) in e_label_value.iter() { dfs_code.push(*from_key, max_to_code + 1, Vertex::NIL_V_LABEL.to_string(), e_label_key.to_string(), to_label_key.to_string()); - self.sub_mining(to_label_value, dfs_code, next_gid, process); - dfs_code.pop_with_set_result(to_label_value); - + self.sub_mining(to_label_value, dfs_code, next_gid, process, result); + dfs_code.pop_with_set_result(to_label_value, result); } } } } - fn generate_next_root<'a>(&'a self, projected: &'a Projected<'a>, dfs_code: &DFSCode, + pub fn generate_next_root<'a>(&'a self, projected: &'a Projected<'a>, dfs_code: &DFSCode, min_rm_path: &Vec, min_label: &str, max_to_code: usize, - ) -> (BTreeMap>>>, + ) -> (BTreeMap>>>, BTreeMap>>) { - + // [from][e_label][to_label] -> Projected let mut new_fwd_root: BTreeMap>> = BTreeMap::new(); // [to][e_label] -> Projected @@ -331,6 +295,7 @@ impl GSpan { (new_fwd_root, new_bck_root) } + fn should_stop_mining(&self, projected: &Projected, dfs_code: &mut DFSCode, next_gid: &mut usize, process: &mut Option, ) -> bool { @@ -642,37 +607,35 @@ impl GSpan { #[cfg(test)] mod tests { - use crate::gspan::result::MaxDFSCodeGraphResult; - use super::*; #[test] fn test_run_single_graph() { // JSON 文件路径 - let filename = r#"json\single-graph.json"#; + let filename = r#"tests\json\single-graph.json"#; match Graph::graph_from_file(&filename, true) { Ok(graph) => { println!("{}", graph.to_str_repr(None)); - let gspan = GSpan::new_with_out_path( + let gspan = GSpan::new( vec![graph], 1, 2, 1, 10, true, - "out-single.txt", - OutType::TXT, ); - let subgraphs = gspan.run(&mut Some(BufWriter::new(File::create("out-process-single.txt").unwrap()))); - - let singleton = MaxDFSCodeGraphResult::get_instance(); + let (subgraphs, result) = gspan.run( + OutType::TXT, + Some(OutSource::Path("out-single.txt".to_string())), + Some(BufWriter::new(File::create("out-process-single.txt").unwrap())), + ); assert_eq!(8, subgraphs); - assert_eq!(5, singleton.get_value_len()); - assert_eq!(12, singleton.get_sum_subgraphs()); + assert_eq!(5, result.get_value_len()); + assert_eq!(12, result.get_sum_subgraphs()); }, Err(err) => { println!("Error: {}", err); @@ -683,7 +646,7 @@ mod tests { #[test] fn test_run_lenet_graph() { // JSON 文件路径 - let filename = r#"json\lenet.json"#; + let filename = r#"tests\json\lenet.json"#; match Graph::graph_from_file(&filename, true) { Ok(graph) => { @@ -692,24 +655,24 @@ mod tests { let file = File::create("out-t.txt").unwrap(); let buffered_writer = BufWriter::new(file); - let gspan = GSpan::new_with_stream( + let gspan = GSpan::new( vec![graph], 1, 2, 1, 10, true, - buffered_writer, - OutType::TXT, ); - let subgraphs = gspan.run(&mut Some(BufWriter::new(File::create("out-t-process.txt").unwrap()))); - - let singleton = MaxDFSCodeGraphResult::get_instance(); + let (subgraphs, result) = gspan.run( + OutType::TXT, + Some(OutSource::Stream(buffered_writer)), + Some(BufWriter::new(File::create("out-t-process.txt").unwrap())), + ); assert_eq!(10, subgraphs); - assert_eq!(2, singleton.get_value_len()); - assert_eq!(4, singleton.get_sum_subgraphs()); + assert_eq!(2, result.get_value_len()); + assert_eq!(4, result.get_sum_subgraphs()); }, Err(err) => { println!("Error: {}", err); diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/dfs_code.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/dfs_code.rs index 006a43b7d6a775751728010fed9da5e664ea67f4..f19b3b571eed94a77ee06d73e5868be477829d7e 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/dfs_code.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/dfs_code.rs @@ -40,12 +40,11 @@ impl DFSCode { .push(DFS::from(from, to, from_label, e_label, to_label)); } - pub fn pop_with_set_result(&mut self, projected: &Projected) -> Option { + pub fn pop_with_set_result(&mut self, projected: &Projected, result: &mut MaxDFSCodeGraphResult) -> Option { if !self.is_push_result { // 记录尽可能远的深度搜索的结果 - let singleton = MaxDFSCodeGraphResult::get_instance(); // println!("pop {} {} {:?}", singleton.min_sup, singleton.inner_min_sup, singleton.out); - self.is_push_result = singleton.add_value(self, projected); + self.is_push_result = result.add_value(self, projected); } return self.dfs_vec.pop(); } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/graph.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/graph.rs index 538519a54aa11479d65c76fcfec1cd5a34b5aa0f..37ac91e12b4fb6d04b99ffd093e57c80b9669af9 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/graph.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/models/graph.rs @@ -206,7 +206,7 @@ mod tests { #[test] fn test_load_single_graph() { - let filename = r#"json\single-graph.json"#; + let filename = r#"tests\json\single-graph.json"#; match Graph::graph_from_file(&filename, true) { Ok(graph) => { @@ -220,7 +220,7 @@ mod tests { #[test] fn test_load_graph() { - let filename = r#"json\single-graph.json"#; + let filename = r#"tests\json\single-graph.json"#; match Graph::graph_from_file(&filename, true) { Ok(graph) => { diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/result.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/result.rs index cb96f46384a27bec40dd02cbaad9ec5431a38c52..ee797ba0bacf32fd78c8b3ef62cbb87b5a28615d 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/result.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/gspan/result.rs @@ -1,23 +1,31 @@ /* * Copyright (c), Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ +use std::{io::Write, sync::mpsc::Sender}; + use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize}; -use std::io::Write; -use std::sync::{Arc, mpsc::Sender, Mutex, MutexGuard, OnceLock}; -use crate::gspan::misc::{inner_support, support}; -use crate::gspan::models::edge::Edge; -use crate::gspan::models::{dfs_code::DFSCode, projected::Projected}; -use crate::io::output::{Edge as OutputEdge, Instance, NodeId, Structure, Vertex as OutputVertex}; +use crate::{ + gspan::{ + misc::{inner_support, support}, + models::{dfs_code::DFSCode, edge::Edge, projected::Projected}, + }, + io::output::{Edge as OutputEdge, Instance, NodeId, Structure, Vertex as OutputVertex}, +}; #[derive(Debug, Clone)] -#[allow(dead_code)] pub enum OutType { TXT, JSON, } +pub enum OutSource { + Channel(Sender), + Stream(W), + Path(String), +} + #[derive(Debug, Serialize, Deserialize)] pub struct JSONResult { pub between_sup: usize, @@ -29,7 +37,7 @@ pub struct JSONResult { } #[derive(Default)] -struct MaxDFSCodeGraphResultInner { +pub struct MaxDFSCodeGraphResult { out: Option>, out_type: Option, channel: bool, // 如果 channel 为 true,则 out 流失效 @@ -41,181 +49,150 @@ struct MaxDFSCodeGraphResultInner { value: Vec<(DFSCode, Vec>)>, } -// 单例结构体 -pub struct MaxDFSCodeGraphResult { - inner: Arc>, -} - impl MaxDFSCodeGraphResult { - // 提供一个方法来获取唯一的实例 - pub fn get_instance() -> &'static MaxDFSCodeGraphResult { - static INSTANCE: OnceLock = OnceLock::new(); - INSTANCE.get_or_init(|| { - MaxDFSCodeGraphResult { - inner: Arc::new(Mutex::new(MaxDFSCodeGraphResultInner::default())) - } - }) - } pub fn set_config( - &self, + &mut self, min_sup: usize, inner_min_sup: usize, - max_pat_min: usize, // Minimum number of vertices - max_pat_max: usize, // Maximum number of vertices + max_pat_min: usize, // Minimum number of vertices + max_pat_max: usize, // Maximum number of vertices out_type: OutType, ) { - let mut guard = self.inner.lock().unwrap(); - guard.min_sup = min_sup; - guard.inner_min_sup = inner_min_sup; - guard.max_pat_min = max_pat_min; - guard.max_pat_max = max_pat_max; - guard.out_type = Some(out_type); + self.min_sup = min_sup; + self.inner_min_sup = inner_min_sup; + self.max_pat_min = max_pat_min; + self.max_pat_max = max_pat_max; + self.out_type = Some(out_type); } - pub fn add_value(&self, dfs_code: &DFSCode, projected: &Projected) -> bool { - let mut guard = self.inner.lock().unwrap(); + pub fn add_value(&mut self, dfs_code: &DFSCode, projected: &Projected) -> bool { // Check if the pattern is frequent enough, between graphs let sup: usize = support(projected); - if sup < guard.min_sup { + if sup < self.min_sup { return false; } // Check if the pattern is frequent enough, inner graph let (_min_inner_sup, max_inner_sup) = inner_support(projected); - if max_inner_sup < guard.inner_min_sup { + if max_inner_sup < self.inner_min_sup { return false; } // Check if the dfs_code vertices.len in [max_pat_min, max_pat_max] - if guard.max_pat_max >= guard.max_pat_min && dfs_code.count_node() > guard.max_pat_max { + if self.max_pat_max >= self.max_pat_min && dfs_code.count_node() > self.max_pat_max { return false; } - if guard.max_pat_min > 0 && dfs_code.count_node() < guard.max_pat_min { + if self.max_pat_min > 0 && dfs_code.count_node() < self.max_pat_min { return false; } let item = (dfs_code.clone(), projected.to_vertex_names_list()); let edges_list = projected.to_edges_list(); - if guard.channel { - self.send_result(&mut guard, sup, _min_inner_sup, max_inner_sup, &item, edges_list); - } else if Option::is_some(&guard.out) { - self.write_result(&mut guard, sup, _min_inner_sup, max_inner_sup, &item, edges_list); + if self.channel { + self.send_result(sup, _min_inner_sup, max_inner_sup, &item, edges_list); + } else if Option::is_some(&self.out) { + self.write_result(sup, _min_inner_sup, max_inner_sup, &item, edges_list); } - guard.value.push(item); + self.value.push(item); true } pub fn get_value_len(&self) -> usize { - let guard = self.inner.lock().unwrap(); - guard.value.len() + self.value.len() } pub fn get_result(&self) -> Vec { - let guard = self.inner.lock().unwrap(); - guard.value.iter().map(|v| { + self.value.iter().map(|v| { let instances = v.1.iter().map(|set| { - let node_ids= set.iter().map(|p| NodeId { gid: p.0, nid: p.1.clone() }).collect::>(); - return Instance { - node_num: node_ids.len(), - node_ids, - edges: vec![], - } + let node_ids = set.iter().map(|p| NodeId { gid: p.0, nid: p.1.clone() }).collect::>(); + return Instance { node_num: node_ids.len(), node_ids, edges: vec![] }; }).collect::>(); return JSONResult { between_sup: 0, inner_min_sup: 0, inner_max_sup: 0, total: instances.len(), - structure: Structure { - tid: 0, - vertices: vec![], - edges: vec![], - }, + structure: Structure { tid: 0, vertices: vec![], edges: vec![] }, instances, }; }).collect::>() } pub fn get_sum_subgraphs(&self) -> usize { - let guard = self.inner.lock().unwrap(); - guard.value.iter().map(|e| e.1.len()).sum() + self.value.iter().map(|e| e.1.len()).sum() } } impl MaxDFSCodeGraphResult { - pub fn set_channel(&self, channel: bool, sender: Option>) { - let mut guard = self.inner.lock().unwrap(); + pub fn set_channel(&mut self, channel: bool, sender: Option>) { if channel { - guard.out = None; + self.out = None; } else { // take 方法会返回 Some(sender),并把原来的字段设置为 None // _sender 离开作用域时会自动调用 drop,因此不需要显式调用 drop - let _sender = guard.sender.take(); + let _sender = self.sender.take(); } - guard.sender = sender; - guard.channel = channel; + self.sender = sender; + self.channel = channel; } - pub fn drop_sender(&self) { - let mut guard = self.inner.lock().unwrap(); - if Option::is_some(&guard.sender) { + pub fn drop_sender(&mut self) { + if Option::is_some(&self.sender) { // take 方法会返回 Some(sender),并把原来的字段设置为 None // _sender 离开作用域时会自动调用 drop,因此不需要显式调用 drop - let _sender = guard.sender.take(); + let _sender = self.sender.take(); } } fn send_result( - &self, - guard: &mut MutexGuard, + &mut self, sup: usize, min_inner_sup: usize, max_inner_sup: usize, item: &(DFSCode, Vec>), edges_list: Vec>, ) { - let id = guard.value.len(); - if let Some(out_type) = &guard.out_type { + let id = self.value.len(); + if let Some(out_type) = &self.out_type { match out_type { - OutType::TXT => { - if let Some(sender) = &mut guard.sender { - let line = report_txt(id, sup, min_inner_sup, max_inner_sup, item, edges_list); + OutType::TXT => + if let Some(sender) = &mut self.sender { + let line = + report_txt(id, sup, min_inner_sup, max_inner_sup, item, edges_list); sender.send(line).unwrap(); - } - } - OutType::JSON => { - if let Some(sender) = &mut guard.sender { - let line = report_json(id, sup, min_inner_sup, max_inner_sup, item, edges_list); + }, + OutType::JSON => + if let Some(sender) = &mut self.sender { + let line = + report_json(id, sup, min_inner_sup, max_inner_sup, item, edges_list); sender.send(line).expect("ERR: MaxDFSCodeGraphResult Channel"); - } - } + }, } } } } impl MaxDFSCodeGraphResult { - pub fn set_stream(&self, out: W) { - let mut guard = self.inner.lock().unwrap(); - guard.channel = false; - guard.out = Some(Box::new(out)); + pub fn set_stream(&mut self, out: W) { + self.channel = false; + self.out = Some(Box::new(out)); } fn write_result( - &self, - guard: &mut MutexGuard, + &mut self, sup: usize, min_inner_sup: usize, max_inner_sup: usize, item: &(DFSCode, Vec>), edges_list: Vec>, ) { - let id = guard.value.len(); - if let Some(out_type) = &guard.out_type { + let id = self.value.len(); + if let Some(out_type) = &self.out_type { match out_type { OutType::TXT => { - if let Some(out) = &mut guard.out { - let line = report_txt(id, sup, min_inner_sup, max_inner_sup, item, edges_list); + if let Some(out) = &mut self.out { + let line = + report_txt(id, sup, min_inner_sup, max_inner_sup, item, edges_list); out.write(&*line.into_bytes()).expect("ERR: MaxDFSCodeGraphResult Stream"); - + // 刷新缓冲区,确保所有数据都被写出 if let Err(e) = out.flush() { eprintln!("Failed to flush output after writing lines: {}", e); @@ -223,11 +200,12 @@ impl MaxDFSCodeGraphResult { } } OutType::JSON => { - if let Some(out) = &mut guard.out { - let line = report_json(id, sup, min_inner_sup, max_inner_sup, item, edges_list); + if let Some(out) = &mut self.out { + let line = + report_json(id, sup, min_inner_sup, max_inner_sup, item, edges_list); out.write(&*line.into_bytes()).expect("ERR: MaxDFSCodeGraphResult Stream"); out.write(b",\n").expect("ERR: MaxDFSCodeGraphResult Stream"); - + // 刷新缓冲区,确保所有数据都被写出 if let Err(e) = out.flush() { eprintln!("Failed to flush output after writing lines: {}", e); @@ -355,23 +333,23 @@ mod tests { #[test] fn test_singleton() { - // 获取单例实例 - let singleton = MaxDFSCodeGraphResult::get_instance(); + // 获取实例 + let mut default = MaxDFSCodeGraphResult::default(); // 创建文件并使用 BufWriter 来提高性能 let file_path = "out_test.json"; let file = File::create(file_path).unwrap(); let buffered_writer = BufWriter::new(file); - singleton.set_config(0, 0, 1, 4, OutType::JSON); - singleton.set_stream(buffered_writer); + default.set_config(0, 0, 1, 4, OutType::JSON); + default.set_stream(buffered_writer); - assert_eq!(0, singleton.get_value_len()); + assert_eq!(0, default.get_value_len()); let mut dfs_code = DFSCode::new(); create_and_add_edges(&mut dfs_code); let projected = Projected::new(); - singleton.add_value(&dfs_code, &projected); - assert_eq!(1, singleton.get_value_len()); + default.add_value(&dfs_code, &projected); + assert_eq!(1, default.get_value_len()); } } diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/lib.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/lib.rs index 2846e579d85a356fccfd919523f21ba446497688..0e1d97c0c2f620cad96f93b33e412ad9b92537d6 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/lib.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/lib.rs @@ -76,23 +76,3 @@ macro_rules! subgraph_command { subgraph_command!(subgraphs_mindir, parse_mindir_model); subgraph_command!(subgraphs_geir, parse_geir_model); subgraph_command!(subgraphs_onnx, parse_onnx_model); - -#[cfg(test)] -mod tests { - use std::env; - - use crate::subgraphs_onnx; - - #[test] - fn test_parse_then_get_subgraph() { - // 获取当前工作目录 - let current_dir = env::current_dir().expect("Failed to get current directory"); - let filename = current_dir.join(r#"onnx\light_bvlc_alexnet.onnx"#); - let path = filename.to_str().expect("Failed to convert path to string"); - match subgraphs_onnx(path, 2, 10) { - Some(subgraphs) => assert_eq!(14, subgraphs.len()), - None => println!("Error"), - } - } - -} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/config.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/config.rs index d181e3ed9773ce65f7add64d04d2a00313ff3c29..3996497f20a31031652491cc3127b157a2e04150 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/config.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/config.rs @@ -18,7 +18,7 @@ pub enum ConfigError { MinVerticesGreaterThanMax(usize, usize), NonNormalizedPath(String), SymlinkNotAllowed(String), - IllegalFilePath(String), + IllegalCharacters(String), } #[derive(Debug)] @@ -183,7 +183,7 @@ impl Config { } } -// 校验路径是否已标准化(不允许是软链接,且存在文件) +// 校验路径是否已标准化(不允许是软链接,且没有非法字符) fn check_normalized_path(path: &str) -> Result<(), ConfigError> { let p = Path::new(path); @@ -192,14 +192,39 @@ fn check_normalized_path(path: &str) -> Result<(), ConfigError> { return Err(ConfigError::SymlinkNotAllowed(format!("Path '{}' is a symbolic link.", path))); } - // 检查是否存在文件 - if !p.is_file() { - return Err(ConfigError::IllegalFilePath(format!("Path '{}' contains illegal character '{}'.", path, illegal_char))); + // 检查是否存在非法字符 + if let Some(illegal_char) = find_illegal_characters(path) { + return Err(ConfigError::IllegalCharacters(format!("Path '{}' contains illegal character '{}'.", path, illegal_char))); } Ok(()) } +fn find_illegal_characters(path: &str) -> Option { + // 定义文件路径非法字符集合 + let illegal_chars = if cfg!(target_os = "windows") { + // Windows 非法字符 + vec!['<', '>', ':', '"', '|', '?', '*', '\0'] + .into_iter() + .chain((0x01..=0x1F).map(|c| c as u8 as char)) // 控制字符 + .collect() + } else { + // Unix-like 系统非法字符 + vec!['\0'] + .into_iter() + .chain((0x01..=0x1F).map(|c| c as u8 as char)) // 控制字符 + .collect::>() + }; + + for c in path.chars() { + if illegal_chars.contains(&c) { + return Some(c); + } + } + + None +} + // 测试代码 #[cfg(test)] mod tests { @@ -232,6 +257,6 @@ mod tests { 3, 5, ); - assert!(matches!(config_result.unwrap_err(), ConfigError::IllegalFilePath(_))); + assert!(matches!(config_result.unwrap_err(), ConfigError::IllegalCharacters(_))); } } \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/gspan_mining.rs b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/gspan_mining.rs index b8a798567917e21bf53f7a1badeedbc11a3fb127..6a03dceaf989207bd3be8d0094bcd73b82800281 100644 --- a/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/gspan_mining.rs +++ b/plugins/mindstudio-insight-plugins/ModelVis/rust/subgraph/src/strategy/gspan_mining.rs @@ -8,8 +8,8 @@ use std::{ use crate::{gspan::{ gspan::GSpan, models::graph::Graph, - result::{JSONResult, MaxDFSCodeGraphResult, OutType}, -}, strategy::config::InputSource}; + result::{JSONResult, OutType}, +}, result::OutSource, strategy::config::InputSource}; use super::mining_strategy::MiningStrategy; @@ -35,36 +35,33 @@ impl MiningStrategy for GSpanMining { println!("Took {}ms", alpha); println!("Mining subgraphs.."); - let gspan = match args.get_output_path() { - Some(file) => GSpan::new_with_out_path(graphs, args.get_min_support(), - args.get_min_inner_support(), args.get_min_vertices(), args.get_max_vertices(), true, file, - args.get_output_type().clone() - ), - None => GSpan::new(graphs, args.get_min_support(), args.get_min_inner_support(), - args.get_min_vertices(), args.get_max_vertices(), true, args.get_output_type().clone(), - ), + let gspan = GSpan::new(graphs, args.get_min_support(), args.get_min_inner_support(), + args.get_min_vertices(), args.get_max_vertices(), true); + + let process_writer: Option> = match args.get_process_path() { + Some(file) => Some(BufWriter::new(File::create(file).unwrap())), + None => None, }; - let mut process_writer: Option> = match args.get_process_path() { + let output_source = match args.get_output_path() { + Some(file) => Some(OutSource::Path(file.to_string())), None => None, - Some(file) => Some(BufWriter::new(File::create(file).unwrap())), }; - let subgraphs = gspan.run(&mut process_writer); + let (subgraphs, result) = gspan.run(args.get_output_type().clone(), output_source, process_writer); let delta = now.elapsed().as_millis(); println!("Finished."); println!("Found {} subgraphs", subgraphs); - let singleton = MaxDFSCodeGraphResult::get_instance(); println!( "Found {}/{} subgraphs (Only Max)", - singleton.get_value_len(), - singleton.get_sum_subgraphs() + result.get_value_len(), + result.get_sum_subgraphs() ); println!("Took {}ms", delta - alpha); println!("Total Took {}ms", delta); fix_json_file(args.get_output_path(), args.get_output_type()); - singleton.get_result() + result.get_result() } fn run_channel(&self, args: super::Config) -> Receiver { @@ -86,39 +83,32 @@ impl MiningStrategy for GSpanMining { println!("Took {}ms", alpha); println!("Mining subgraphs.."); - let gspan = match args.get_output_path() { - Some(file) => GSpan::new_with_out_path(graphs, args.get_min_support(), - args.get_min_inner_support(), args.get_min_vertices(), args.get_max_vertices(), true, file, - args.get_output_type().clone(), - ), - None => GSpan::new(graphs, args.get_min_support(), args.get_min_inner_support(), - args.get_min_vertices(), args.get_max_vertices(), true, args.get_output_type().clone(), - ), - }; + let gspan = GSpan::new(graphs, args.get_min_support(), args.get_min_inner_support(), + args.get_min_vertices(), args.get_max_vertices(), true); + let (tx, rx): (Sender, Receiver) = mpsc::channel(); let process_path = (*args.get_process_path()).clone(); + let output_type = args.get_output_type().clone(); thread::spawn(move || { - let singleton = MaxDFSCodeGraphResult::get_instance(); - singleton.set_channel(true, Some(tx)); - - let mut process_writer: Option> = match process_path { + + let process_writer: Option> = match process_path { None => None, Some(file) => Some(BufWriter::new(File::create(file).unwrap())), }; - let subgraphs = gspan.run(&mut process_writer); + let (subgraphs, mut result) = gspan.run(output_type, Some(OutSource::Channel(tx)), process_writer); let delta = now.elapsed().as_millis(); println!("Finished."); println!("Found {} subgraphs", subgraphs); println!( "Found {}/{} subgraphs (Only Max)", - singleton.get_value_len(), - singleton.get_sum_subgraphs() + result.get_value_len(), + result.get_sum_subgraphs() ); println!("Took {}ms", delta - alpha); println!("Total Took {}ms", delta); - singleton.drop_sender(); + result.drop_sender(); }); fix_json_file(args.get_output_path(), args.get_output_type()); @@ -159,13 +149,13 @@ fn fix_json_file(output_path: &Option, output_type: &OutType) { #[cfg(test)] mod tests { use super::*; - use crate::gspan::result::{MaxDFSCodeGraphResult, OutType}; + use crate::gspan::result::OutType; use crate::strategy::Config; #[test] fn test_run_lenet_graph() { // JSON 文件路径 - let filename = r#"json\lenet.json"#; + let filename = r#"tests\json\lenet.json"#; let gspan_mining = GSpanMining; @@ -182,10 +172,8 @@ mod tests { Ok(config) => { let result = gspan_mining.run(config); - let singleton = MaxDFSCodeGraphResult::get_instance(); - - assert_eq!(2, singleton.get_value_len()); - assert_eq!(4, singleton.get_sum_subgraphs()); + assert_eq!(2, result.len()); + assert_eq!(4, result.iter().map(|r| r.instances.len()).sum::()); println!("{:?}", result); } Err(e) => eprintln!("Failed to create config: {:?}", e), @@ -195,7 +183,7 @@ mod tests { #[test] fn test_run_lenet_graph_parsed() { // JSON 文件路径 - let graph = Graph::graph_from_file(r#"json\lenet.json"#, true).unwrap(); + let graph = Graph::graph_from_file(r#"tests\json\lenet.json"#, true).unwrap(); let gspan_mining = GSpanMining; @@ -212,10 +200,8 @@ mod tests { Ok(config) => { let result = gspan_mining.run(config); - let singleton = MaxDFSCodeGraphResult::get_instance(); - - assert_eq!(2, singleton.get_value_len()); - assert_eq!(4, singleton.get_sum_subgraphs()); + assert_eq!(2, result.len()); + assert_eq!(4, result.iter().map(|r| r.instances.len()).sum::()); println!("{:?}", result); } Err(e) => eprintln!("Failed to create config: {:?}", e), @@ -225,7 +211,7 @@ mod tests { #[test] fn test_run_channel_lenet_graph() { // JSON 文件路径 - let filename = r#"json\lenet.json"#; + let filename = r#"tests\json\lenet.json"#; let gspan_mining = GSpanMining;