diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/hierarchy.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/hierarchy.py index acb60aeae93ab597cf50b4fd097969ffe5eeedc8..1a83932002196f47ece21cc99069fac231469eb3 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/hierarchy.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/hierarchy.py @@ -53,11 +53,17 @@ class Hierarchy: def extract_label_name(node_name, node_type): splited_subnode_name = node_name.split('.') splited_label = [] + # 在展开层级时,将父级层级名称相关去除,仅保留子节点本身名称信息 + # 如Module.layer1.1.relu.ReLU.forward.0中的父级名称Module.layer1.1去除,仅保留子级的relu.ReLU.forward.1 + # 如Module.layer4.0.BasicBlock.forward.0中的父级名称Module.1去除,仅保留子级的layer4.0.BasicBlock.forward.0 if node_type == MODULE: if len(splited_subnode_name) < 4: return node_name splited_label = splited_subnode_name[-4:] if not splited_subnode_name[ -4].isdigit() else splited_subnode_name[-5:] + # 在展开层级时,将父级层级名称相关去除,仅保留API子节点本身名称信息, + # 如 Module.layer1.1.ApiList.1 中的父级名称Module.layer1.1去除,仅保留子级的ApiList.1 + # 如 Module.layer1.1.ApiList.0.1 中的父级名称Module.layer1.1去除,仅保留子级的ApiList.0.1 else: if len(splited_subnode_name) < 2: return node_name diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py index 70ef0b2172a9f2cfe9f4dbaa8fdfffb25032cda9..36c3aa157915ad06b0ba625864eaab7432bb9d5a 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py @@ -135,7 +135,7 @@ class GraphService: config['colors'] = config_data_run.get('colors') # 读取目录下配置文件列表 config_files = GraphUtils.find_config_files(run) - config['matchedConfigFiles'] = config_files + config['matchedConfigFiles'] = config_files or [] config['task'] = graph_data.get('task') return {'success': True, 'data': config} except Exception as e: @@ -329,7 +329,7 @@ class GraphService: config_data_run['colors'] = colors config_data[run] = config_data_run GraphState.set_global_value("config_data", config_data) - GraphUtils.safe_save_data(first_file_data, run, first_run_tag) + GraphUtils.safe_save_data(first_file_data, run, f"{first_run_tag}.vis") return {'success': True, 'error': None, 'data': {}} except Exception as e: return {'success': False, 'error': str(e), 'data': None} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py index 979e5b4f34078941333f4f5aa07cacb4309e5b33..6e4fcabaeefeca8eceaa599bb4d8de82705ee041 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py @@ -36,6 +36,9 @@ SINGLE = 'Single' # 前端节点类型 EXPAND_MODULE = 0 UNEXPAND_NODE = 1 +# 权限码 +PERM_GROUP_WRITE = 0o020 +PERM_OTHER_WRITE = 0o002 # 后端节点类型 MODULE = 0 diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py index 41a86a45bfbee8da730106a1aee3668ba02c8f2e..b5efe347bb454db8c0d529f390d1b914c049038c 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -19,12 +19,14 @@ import os import json import re import stat +import sys from functools import cmp_to_key from pathlib import Path from tensorboard.util import tb_logging -from .global_state import GraphState, FILE_NAME_REGEX, MAX_FILE_SIZE +from .global_state import GraphState, FILE_NAME_REGEX, MAX_FILE_SIZE, PERM_GROUP_WRITE, PERM_OTHER_WRITE logger = tb_logging.get_logger() +FILE_PATH_MAX_LENGTH = 4096 class GraphUtils: @@ -214,35 +216,24 @@ class GraphUtils: def safe_save_data(data, run_name, tag): runs = GraphState.get_global_value('runs', {}) run = runs.get(run_name) or run_name - safe_base_dir = GraphState.get_global_value('logdir') if run is None or tag is None: error_message = 'The query parameters "run" and "tag" are required' return None, error_message try: - # 构建文件路径并标准化 - file_path = os.path.join(run, tag) - file_path = os.path.normpath(file_path) - # 权限校验:检查目录是否有写权限 - if not os.access(run, os.W_OK): - raise PermissionError(f"No write permission for directory: {run}\n") - # 检查 run 目录是否存在,如果不存在则创建 - if not os.path.exists(run): - os.makedirs(run, exist_ok=True) - os.chmod(run, 0o750) # 检查 tag 是否为合法文件名 if not re.match(FILE_NAME_REGEX, tag): raise ValueError(f"Invalid tag: {tag}.") - # 检查文件路径是否合法,防止路径遍历攻击 - if not file_path.startswith(os.path.abspath(run)): - raise ValueError(f"Invalid file path: {file_path}. Potential path traversal attack.\n") - # 基础路径校验 - if not GraphUtils.is_relative_to(file_path, safe_base_dir): - raise ValueError(f"Path out of bounds: {file_path}") - if os.path.islink(file_path): - raise RuntimeError("The target file is a symbolic link") - if os.path.islink(run): - raise RuntimeError(f"Parent directory contains a symbolic link") - # # 尝试写入文件 + # 构建文件路径并标准化 + file_path = os.path.join(run, tag) + # 目录安全校验 + success, error = GraphUtils.safe_check_save_file_path(run, True) + if not success: + raise PermissionError(error) + # 文件安全校验 + success, error = GraphUtils.safe_check_save_file_path(file_path) + if not success: + raise PermissionError(error) + # 尝试写入文件 with open(file_path, "w", encoding="utf-8") as file: json.dump(data, file, ensure_ascii=False, indent=4) os.chmod(file_path, 0o640) @@ -255,14 +246,13 @@ class GraphUtils: return None, 'Invalid data' except OSError as e: logger.error(f"Failed to create directory: {run}. Error: {e}\n") - return None, 'failed to create directory {run}' + return None, 'failed to create directory ' except Exception as e: logger.error(f'Error: File "{file_path}" is not accessible. Error: {e}') return None, 'failed to save file' @staticmethod def safe_load_data(run_name, tag, only_check=False): - runs = GraphState.get_global_value('runs', {}) run_dir = runs.get(str(run_name)) or run_name """Load a single .vis file from a given directory based on the tag.""" @@ -271,33 +261,14 @@ class GraphUtils: return None, error_message try: file_path = os.path.join(run_dir, tag) - file_path = os.path.normpath(file_path) # 标准化路径 - # 解析真实路径(包含符号链接跟踪) - real_path = os.path.realpath(file_path) - safe_base_dir = GraphState.get_global_value('logdir') - # 安全验证1:路径归属检查(防止越界访问) - if not GraphUtils.is_relative_to(file_path, safe_base_dir): - raise RuntimeError(f"Path out of bounds:") - # 安全验证2:禁止符号链接文件 - if os.path.islink(file_path): - raise RuntimeError(f"Detected symbolic link file") - if os.path.islink(run_dir): - raise RuntimeError(f"Parent directory contains a symbolic link") - # 安全验证3:二次文件类型检查(防御TOCTOU攻击) - if not os.path.isfile(real_path): - raise RuntimeError(f"Path is not a regular file") - # 安全检查4:文件存在性验证 - if not os.path.exists(real_path): - raise FileNotFoundError(f"File does not exist") - # 权限验证 - if not os.stat(real_path).st_mode & stat.S_IRUSR: - raise PermissionError(f"File has no read permissions") - # 文件大小验证 - if os.path.getsize(real_path) > MAX_FILE_SIZE: - file_size = GraphUtils.bytes_to_human_readable(os.path.getsize(real_path)) - max_size = GraphUtils.bytes_to_human_readable(MAX_FILE_SIZE) - raise RuntimeError( - f"File size exceeds limit ({file_size} > {max_size})") + # 目录安全校验 + success, error = GraphUtils.safe_check_load_file_path(run_dir, True) + if not success: + raise PermissionError(error) + # 文件安全校验 + success, error = GraphUtils.safe_check_load_file_path(file_path) + if not success: + raise PermissionError(error) # 读取文件比较耗时,支持onlyCheck参数,仅进行安全校验 if only_check: return True, None @@ -309,8 +280,103 @@ class GraphUtils: return None, "File is not a valid JSON file!" except Exception as e: logger.error(f'Error: File "{file_path}" is not accessible. Error: {e}') - return None, 'failed to load file' + return None, e + @staticmethod + def safe_check_save_file_path(file_path, is_dir=False): + file_path = os.path.normpath(file_path) # 标准化路径 + real_path = os.path.realpath(file_path) + safe_base_dir = GraphState.get_global_value('logdir') + try: + # 安全验证:路径长度检查 + if len(file_path) > FILE_PATH_MAX_LENGTH: + raise PermissionError(f"Path length exceeds limit") + # 安全验证:基础路径校验 + if not GraphUtils.is_relative_to(file_path, safe_base_dir): + raise ValueError(f"Path out of bounds: {file_path}") + if not is_dir and not os.path.exists(file_path): + return True, None + st = os.stat(file_path) + # 安全验证:禁止符号链接文件 + if os.path.islink(file_path): + raise PermissionError("The target file is a symbolic link") + # 安全验证:检查目录是否存在,如果不存在则创建 + if is_dir and not os.path.exists(real_path): + os.makedirs(real_path, exist_ok=True) + os.chmod(file_path, 0o640) + # 权限校验:检查是否有写权限 + if not os.stat(file_path).st_mode & stat.S_IWUSR: + raise PermissionError(f"No write permission for directory\n") + # 安全验证: 非windows系统下,属主检查 + if os.name != 'nt': + current_uid = os.getuid() + # 如果是root用户,跳过后续权限检查 + if current_uid == 0: + return True + # 属主检查 + if st.st_uid != current_uid: + raise PermissionError(f"Directory is not owned by the current user") + # group和其他用户不可写检查 + if st.st_mode & PERM_GROUP_WRITE or st.st_mode & PERM_OTHER_WRITE: + raise PermissionError(f"Directory has group or other write permission") + return True, None + except Exception as e: + logger.error(e) + return False, e + + @staticmethod + def safe_check_load_file_path(file_path, is_dir=False): + # 权限常量定义 + file_path = os.path.normpath(file_path) # 标准化路径 + real_path = os.path.realpath(file_path) + safe_base_dir = GraphState.get_global_value('logdir') + st = os.stat(real_path) + try: + # 安全验证:路径长度检查 + if len(real_path) > FILE_PATH_MAX_LENGTH: + raise PermissionError(f"Path length exceeds limit") + # 安全验证:路径归属检查(防止越界访问) + if not GraphUtils.is_relative_to(file_path, safe_base_dir): + raise PermissionError(f"Path out of bounds") + # 安全检查:文件存在性验证 + if not os.path.exists(real_path): + raise FileNotFoundError(f"File does not exist") + # 安全验证:禁止符号链接文件 + if os.path.islink(file_path): + raise PermissionError(f"Detected symbolic link file") + # 安全验证:文件类型检查(防御TOCTOU攻击) + # 文件类型 + if not is_dir and not os.path.isfile(real_path): + raise PermissionError(f"Path is not a regular file") + # 目录类型 + if is_dir and not Path(real_path).is_dir(): + raise PermissionError(f"Directory does not exist") + # 可读性检查 + if not st.st_mode & stat.S_IRUSR: + raise PermissionError(f"Directory lacks read permission for others") + # 文件大小校验 + if not is_dir and os.path.getsize(file_path) > MAX_FILE_SIZE: + file_size = GraphUtils.bytes_to_human_readable(os.path.getsize(file_path)) + max_size = GraphUtils.bytes_to_human_readable(MAX_FILE_SIZE) + raise PermissionError( + f"File size exceeds limit ({file_size} > {max_size})") + # 非windows系统下,属主检查 + if os.name != 'nt': + current_uid = os.getuid() + # 如果是root用户,跳过后续权限检查 + if current_uid == 0: + return True + # 属主检查 + if st.st_uid != current_uid: + raise PermissionError(f"Directory is not owned by the current user") + # group和其他用户不可写检查 + if st.st_mode & PERM_GROUP_WRITE or st.st_mode & PERM_OTHER_WRITE: + raise PermissionError(f"Directory has group or other write permission") + return True, None + except Exception as e: + logger.error(e) + return False, e + @staticmethod def find_config_files(run_name): """ @@ -322,28 +388,13 @@ class GraphUtils: run = runs.get(run_name) dir_path = Path(run) try: - file_path = os.path.normpath(run) # 标准化路径 - # 解析真实路径(包含符号链接跟踪) - real_path = os.path.realpath(file_path) - safe_base_dir = GraphState.get_global_value('logdir') - # 安全验证1:路径归属检查(防止越界访问) - if not GraphUtils.is_relative_to(file_path, safe_base_dir): - raise RuntimeError(f"Path out of bounds:") - # 安全验证2:禁止符号链接文件 - if os.path.islink(file_path): - raise RuntimeError(f"Detected symbolic link file") - # 安全检查3:文件存在性验证 - if not os.path.exists(real_path): - raise FileNotFoundError(f"File does not exist") - # 权限验证 - if not os.stat(real_path).st_mode & stat.S_IRUSR: - raise PermissionError(f"File has no read permissions") - if not dir_path.is_dir(): - raise ValueError(f"The provided path '{run}' is not a valid directory.") - return [ - file.name for file in dir_path.iterdir() - if file.is_file() and file.name.endswith('.vis.config') - ] + if GraphUtils.safe_check_load_file_path(run, True): + return [ + file.name for file in dir_path.iterdir() + if file.is_file() and file.name.endswith('.vis.config') + ] + else: + return [] except Exception as e: logger.error(e) return []