diff --git a/ascend_deployer/downloader/download_util.py b/ascend_deployer/downloader/download_util.py index 561fdbc5b05e229def8a3f6022f82a6a6b60c495..cb7a2b5a544232dfba25a226aadba37deea72104 100644 --- a/ascend_deployer/downloader/download_util.py +++ b/ascend_deployer/downloader/download_util.py @@ -30,7 +30,9 @@ from pathlib import PurePath from urllib import request from urllib.error import ContentTooShortError, URLError from typing import Optional - +base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(0, base_dir) +from ascend_deployer.module_utils.path_manager import get_validated_env from . import logger_config REFERER = "https://www.hiascend.com/" @@ -89,9 +91,9 @@ def get_download_path(): return cur if platform.system() == 'Linux': - deployer_home = os.getenv('HOME') - if os.getenv('ASCEND_DEPLOYER_HOME') is not None: - deployer_home = os.getenv('ASCEND_DEPLOYER_HOME') + deployer_home = get_validated_env('HOME') + if get_validated_env('ASCEND_DEPLOYER_HOME') is not None: + deployer_home = get_validated_env('ASCEND_DEPLOYER_HOME') else: deployer_home = os.getcwd() diff --git a/ascend_deployer/downloader/logger_config.py b/ascend_deployer/downloader/logger_config.py index 7544e387047ee9d7572de61520b8f4aa77ab5ed2..641c052d9530ccb603d9ba9b96bcb1180edbd6af 100644 --- a/ascend_deployer/downloader/logger_config.py +++ b/ascend_deployer/downloader/logger_config.py @@ -20,12 +20,18 @@ import logging.handlers import os import platform import stat +import sys + +base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(0, base_dir) +from ascend_deployer.module_utils.path_manager import get_validated_env class RotatingFileHandler(logging.handlers.RotatingFileHandler): """ rewrite RotatingFileHandler, assign permissions to downloader.log and downloader.log.* """ + def doRollover(self): largest_backfile = "{}.{}".format(self.baseFilename, self.backupCount) if os.path.exists(largest_backfile): @@ -48,9 +54,9 @@ class BasicLogConfig(object): else: deployer_home = '' if platform.system() == 'Linux': - deployer_home = os.getenv('HOME') - if os.getenv('ASCEND_DEPLOYER_HOME') is not None: - deployer_home = os.getenv('ASCEND_DEPLOYER_HOME') + deployer_home = get_validated_env('HOME') + if get_validated_env('ASCEND_DEPLOYER_HOME') is not None: + deployer_home = get_validated_env('ASCEND_DEPLOYER_HOME') else: deployer_home = os.getcwd() parent_dir = os.path.join(deployer_home, 'ascend-deployer') @@ -71,15 +77,15 @@ class BasicLogConfig(object): os.chmod(LOG_FILE_OPERATION, stat.S_IRUSR | stat.S_IWUSR) USER_NAME = getpass.getuser() - CLIENT_IP = os.getenv('SSH_CLIENT', 'localhost').split()[0] + CLIENT_IP = (get_validated_env('SSH_CLIENT') or 'localhost').split()[0] EXTRA = {'user_name': USER_NAME, 'client_ip': CLIENT_IP} LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S' LOG_FORMAT_STRING = \ - "%(asctime)s downloader [%(levelname)s] " \ - "[%(filename)s:%(lineno)d %(funcName)s] %(message)s" + "%(asctime)s downloader [%(levelname)s] " \ + "[%(filename)s:%(lineno)d %(funcName)s] %(message)s" LOG_FORMAT_STRING_OPERATION = \ - "%(asctime)s localhost [%(levelname)s] " \ - "[%(filename)s:%(lineno)d %(funcName)s] %(message)s" + "%(asctime)s localhost [%(levelname)s] " \ + "[%(filename)s:%(lineno)d %(funcName)s] %(message)s" LOG_LEVEL = logging.INFO ROTATING_CONF = dict( @@ -88,6 +94,7 @@ class BasicLogConfig(object): backupCount=5, encoding="UTF-8") + LOG_CONF = BasicLogConfig() diff --git a/ascend_deployer/install.sh b/ascend_deployer/install.sh index db15292faf6d531761c015f191470f627c4523be..0c3b8e71995adb937e2149ef33f766ab4078873c 100644 --- a/ascend_deployer/install.sh +++ b/ascend_deployer/install.sh @@ -10,9 +10,9 @@ main() { fi python3 -V > /dev/null 2>&1 if [[ $? != 0 ]]; then - python ${BASE_DIR}/ascend_deployer.py $* + python ${BASE_DIR}/start_deploy.py $* else - python3 ${BASE_DIR}/ascend_deployer.py $* + python3 ${BASE_DIR}/start_deploy.py $* fi } diff --git a/ascend_deployer/large_scale_deploy/tools/log_tool.py b/ascend_deployer/large_scale_deploy/tools/log_tool.py index 2fe6e8b0ec7a3a67e499ac5d754c5f1e4a1e8838..0e2918dca7f4f8e62e9c7388e25d6008cf7d2a48 100644 --- a/ascend_deployer/large_scale_deploy/tools/log_tool.py +++ b/ascend_deployer/large_scale_deploy/tools/log_tool.py @@ -4,6 +4,10 @@ import os import stat import sys +base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.insert(0, base_dir) +from ascend_deployer.module_utils.path_manager import get_validated_env + CUR_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -14,8 +18,8 @@ class LogTool: if 'site-packages' not in CUR_DIR and 'dist-packages' not in CUR_DIR: log_dir = os.path.dirname(CUR_DIR) else: - if os.getenv('ASCEND_DEPLOYER_HOME'): - deployer_home = os.getenv('ASCEND_DEPLOYER_HOME') + if get_validated_env('ASCEND_DEPLOYER_HOME'): + deployer_home = get_validated_env('ASCEND_DEPLOYER_HOME') else: deployer_home = os.getcwd() log_dir = os.path.join(deployer_home, "ascend-deployer") diff --git a/ascend_deployer/module_utils/check_library_utils/npu_checks.py b/ascend_deployer/module_utils/check_library_utils/npu_checks.py index b931959b4eaeadcf67bf67ceba391ca5bb5596be..d0f83876fce9cdd4cf990e422953ca0d4470210e 100644 --- a/ascend_deployer/module_utils/check_library_utils/npu_checks.py +++ b/ascend_deployer/module_utils/check_library_utils/npu_checks.py @@ -198,4 +198,4 @@ class NPUCheck: if status != "|" and status != "OK": util.record_error("[ASCEND][[ERROR]] Critical issue with NPU, please check the health of card.", self.error_messages) - return + return \ No newline at end of file diff --git a/ascend_deployer/module_utils/path_manager.py b/ascend_deployer/module_utils/path_manager.py index c16b3a8036693b1a40270d884b2d9d476c874d4c..32b7ce6fb5ee98137c1304b28d3015ccd38b0b56 100644 --- a/ascend_deployer/module_utils/path_manager.py +++ b/ascend_deployer/module_utils/path_manager.py @@ -1,7 +1,81 @@ +import errno import os.path import shutil +import string _CUR_DIR = os.path.dirname(__file__) +PATH_WHITE_LIST_LIN = string.digits + string.ascii_letters + '~-+_./ ' +MAX_PATH_LEN = 4096 + + +def get_validated_env( + env_name, + whitelist=PATH_WHITE_LIST_LIN, + min_length=1, + max_length=MAX_PATH_LEN, + check_symlink=True +): + """ + 获取并验证环境变量 (兼容 Python 2/3) + :param env_name: 环境变量名称 + :param whitelist: 允许的值列表 + :param min_length: 最小长度限制 + :param max_length: 最大长度限制 + :param check_symlink: 是否检查软链接 + :return: 验证通过的环境变量值 + :raises ValueError: 验证失败时抛出 + """ + value = os.getenv(env_name) + + if value is None: + return None + + # 白名单校验 + for char in value: + if char not in whitelist: + raise ValueError( + "The path is invalid. The path can contain only char in '{}'".format(whitelist)) + + # 长度校验 + str_len = len(value) + if min_length is not None and str_len < min_length: + raise ValueError( + "Value for {} is too short. Minimum length: {}, actual: {}".format( + env_name, min_length, str_len + ) + ) + + if max_length is not None and str_len > max_length: + raise ValueError( + "Value for {} is too long. Maximum length: {}, actual: {}".format( + env_name, max_length, str_len + ) + ) + + # 路径安全校验 + if check_symlink: + # 在 Python 2/3 中正确处理 unicode 路径 + if isinstance(value, bytes): + path_value = value.decode('utf-8', 'replace') + else: + path_value = value + # 软链接检查 + if check_symlink: + try: + # 检查路径是否存在且是符号链接 + if os.path.lexists(path_value) and os.path.islink(path_value): + raise ValueError( + "Path for {} is a symlink: {}. Symlinks are not allowed for security reasons.".format( + env_name, path_value + ) + ) + except (OSError, IOError) as e: + # 处理文件系统访问错误 + if e.errno != errno.ENOENT: # 忽略文件不存在的错误 + raise ValueError( + "Error checking symlink for {}: {} - {}".format(env_name, path_value, str(e)) + ) + return value class ProjectPath: diff --git a/ascend_deployer/ascend_deployer.py b/ascend_deployer/start_deploy.py similarity index 100% rename from ascend_deployer/ascend_deployer.py rename to ascend_deployer/start_deploy.py diff --git a/ascend_deployer/utils.py b/ascend_deployer/utils.py index f6938a531c469af172f4f8a621ad42b9d3440aa1..9955f168622ae647c1dcf07ba32a64e451d69688 100644 --- a/ascend_deployer/utils.py +++ b/ascend_deployer/utils.py @@ -17,7 +17,6 @@ import json import shlex import stat -import sys import argparse import getpass import logging @@ -26,8 +25,9 @@ import platform import shutil import re import os +import sys from subprocess import PIPE, Popen - +from module_utils.path_manager import get_validated_env ROOT_PATH = SRC_PATH = os.path.dirname(__file__) NEXUS_SENTINEL_FILE = os.path.expanduser('~/.local/nexus.sentinel') MODE_700 = stat.S_IRWXU @@ -85,7 +85,7 @@ def copy_scripts(): if 'site-packages' in ROOT_PATH or 'dist-packages' in ROOT_PATH: deployer_home = os.getcwd() if platform.system() == 'Linux': - deployer_home = os.getenv('ASCEND_DEPLOYER_HOME', os.getenv('HOME')) + deployer_home = get_validated_env('ASCEND_DEPLOYER_HOME') or get_validated_env('HOME') ROOT_PATH = os.path.join(deployer_home, 'ascend-deployer') copy_scripts() @@ -172,7 +172,7 @@ LOG_OPERATION_FILE = os.path.join(ROOT_PATH, 'install_operation.log') class UserHostFilter(logging.Filter): user = getpass.getuser() - host = os.getenv('SSH_CLIENT', 'localhost').split()[0] + host = (get_validated_env('SSH_CLIENT') or 'localhost').split()[0] def filter(self, record): record.user = self.user diff --git a/setup.py b/setup.py index 001eb2addff12c63c97e20fac6c212aed8b6b3bf..6e540c5a8693d278ba37252a962d0a3f69527477 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ setuptools.setup( entry_points={ # Optional 'console_scripts': [ 'ascend-download=ascend_deployer.ascend_download:main', - 'ascend-deployer=ascend_deployer.ascend_deployer:main', + 'ascend-deployer=ascend_deployer.start_deploy:main', 'large-scale-deployer=ascend_deployer.large_scale_deployer:main' ] }, diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf82dcb216946b47a673af9acb22a12dd6d5662e --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,74 @@ +import string +import unittest +import errno +from unittest.mock import patch +from ascend_deployer.utils import get_validated_env +PATH_WHITE_LIST_LIN = string.digits + string.ascii_letters + '~-+_./ ' +MAX_PATH_LEN = 4096 + +class TestGetValidatedEnv(unittest.TestCase): + + @patch('os.getenv') + def test_env_not_set(self, mock_getenv): + mock_getenv.return_value = None + with self.assertRaises(ValueError) as context: + get_validated_env('TEST_ENV') + self.assertEqual(str(context.exception), "Environment variable TEST_ENV is not set") + + @patch('os.getenv') + def test_env_value_not_in_whitelist(self, mock_getenv): + mock_getenv.return_value = 'invalid_value(' + with self.assertRaises(ValueError) as context: + get_validated_env('TEST_ENV') + self.assertEqual(str(context.exception), "The path is invalid. The path can contain only char in '{}'".format(PATH_WHITE_LIST_LIN)) + + @patch('os.getenv') + def test_env_value_too_short(self, mock_getenv): + mock_getenv.return_value = 'a' + with self.assertRaises(ValueError) as context: + get_validated_env('TEST_ENV', min_length=2) + self.assertEqual(str(context.exception), "Value for TEST_ENV is too short. Minimum length: 2, actual: 1") + + @patch('os.getenv') + def test_env_value_too_long(self, mock_getenv): + mock_getenv.return_value = 'a' * (MAX_PATH_LEN + 1) + with self.assertRaises(ValueError) as context: + get_validated_env('TEST_ENV', max_length=MAX_PATH_LEN) + self.assertEqual(str(context.exception), "Value for TEST_ENV is too long. Maximum length: {}, actual: {}".format(MAX_PATH_LEN, MAX_PATH_LEN + 1)) + + @patch('os.getenv') + @patch('os.path.lexists') + @patch('os.path.islink') + def test_env_value_is_symlink(self, mock_islink, mock_lexists, mock_getenv): + mock_getenv.return_value = '/path/to/symlink' + mock_lexists.return_value = True + mock_islink.return_value = True + with self.assertRaises(ValueError) as context: + get_validated_env('TEST_ENV', check_symlink=True) + self.assertEqual(str(context.exception), "Path for TEST_ENV is a symlink: /path/to/symlink. Symlinks are not allowed for security reasons.") + + @patch('os.getenv') + @patch('os.path.lexists') + @patch('os.path.islink') + def test_env_value_is_not_symlink(self, mock_islink, mock_lexists, mock_getenv): + mock_getenv.return_value = '/path/to/file' + mock_lexists.return_value = True + mock_islink.return_value = False + self.assertEqual(get_validated_env('TEST_ENV', check_symlink=True), '/path/to/file') + + @patch('os.getenv') + @patch('os.path.lexists') + def test_env_value_lexists_error(self, mock_lexists, mock_getenv): + mock_getenv.return_value = '/path/to/file' + mock_lexists.side_effect = OSError(errno.ENOENT, 'No such file or directory') + self.assertEqual(get_validated_env('TEST_ENV', check_symlink=True), '/path/to/file') + + @patch('os.getenv') + @patch('os.path.lexists') + def test_env_value_lexists_io_error(self, mock_lexists, mock_getenv): + mock_getenv.return_value = '/path/to/file' + mock_lexists.side_effect = IOError(errno.ENOENT, 'No such file or directory') + self.assertEqual(get_validated_env('TEST_ENV', check_symlink=True), '/path/to/file') + +if __name__ == '__main__': + unittest.main() \ No newline at end of file