From cfb4bb8f0d423197dcc956a52be4fa91b5347d71 Mon Sep 17 00:00:00 2001 From: "Wu,Qiang-Roy" Date: Mon, 21 Jul 2025 19:56:53 +0800 Subject: [PATCH] fix security issue --- .../downloader/parallel_file_downloader.py | 7 ++- ascend_deployer/scripts/nexus.py | 4 +- ascend_deployer/start_deploy.py | 4 +- ascend_deployer/utils.py | 50 +++++++++++++++++++ roce-tool/roce_tool/config.py | 49 ++++++++++++++++++ test/test_utils.py | 22 ++++++++ 6 files changed, 132 insertions(+), 4 deletions(-) diff --git a/ascend_deployer/downloader/parallel_file_downloader.py b/ascend_deployer/downloader/parallel_file_downloader.py index 644f9cb4..f0d91ed2 100644 --- a/ascend_deployer/downloader/parallel_file_downloader.py +++ b/ascend_deployer/downloader/parallel_file_downloader.py @@ -160,7 +160,12 @@ def get_no_hash_result(file_info: DownloadFileInfo) -> CalcHashResult: class ParallelDownloader: - _MAX_DOWNLOAD_THREAD_NUM = int(os.environ.get("ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE", 16)) + __thread_num = ( + int(os.environ.get("ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE")) + if os.environ.get("ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE", "").isdigit() + else 16 + ) + _MAX_DOWNLOAD_THREAD_NUM = __thread_num if __thread_num > 0 else 16 _MAX_CALC_HASH_NUM = min(multiprocessing.cpu_count(), 32) def __init__(self, file_info_list: List[DownloadFileInfo], parent_instance=None): diff --git a/ascend_deployer/scripts/nexus.py b/ascend_deployer/scripts/nexus.py index ed0414f7..5b432d46 100644 --- a/ascend_deployer/scripts/nexus.py +++ b/ascend_deployer/scripts/nexus.py @@ -30,7 +30,7 @@ import time sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from utils import ROOT_PATH +from utils import ROOT_PATH, Validator try: from urllib.parse import urljoin @@ -59,6 +59,8 @@ class OsRepository: def __init__(self, ip=None, port=58081): try: self.nexus_run_ip = ip or os.environ["SSH_CONNECTION"].split()[2] + if not Validator().is_valid_ip(self.nexus_run_ip): + raise RuntimeError("nexus_run_ip is invalid, please check env variable SSH_CONNECTION") self.nexus_run_port = port self.working_on_ipv6 = False if ":" in self.nexus_run_ip: # ipv6格式需要用括号包住域名部分 diff --git a/ascend_deployer/start_deploy.py b/ascend_deployer/start_deploy.py index e2751453..6353fe66 100644 --- a/ascend_deployer/start_deploy.py +++ b/ascend_deployer/start_deploy.py @@ -87,8 +87,8 @@ class CLI(object): self.parser.add_argument("--install", dest="install", nargs="+", choices=utils.install_items, action=utils.ValidChoices, metavar="", help="Install specific package: %(choices)s") - self.parser.add_argument("--stdout_callback", dest="stdout_callback", - help="set display plugin, e.g. ansible_log") + self.parser.add_argument("--stdout_callback", dest="stdout_callback", choices=utils.stdout_callbacks, + help="set display plugin, e.g. default") self.parser.add_argument("--install-scene", dest="scene", nargs="?", choices=utils.scene_items, metavar="", help="Install specific scene: %(choices)s") self.parser.add_argument("--patch", dest="patch", nargs="+", choices=utils.patch_items, diff --git a/ascend_deployer/utils.py b/ascend_deployer/utils.py index 9955f168..b72acddf 100644 --- a/ascend_deployer/utils.py +++ b/ascend_deployer/utils.py @@ -16,6 +16,7 @@ # =========================================================================== import json import shlex +import socket import stat import argparse import getpass @@ -163,6 +164,9 @@ test_items = ['all', 'firmware', 'driver', 'nnrt', 'nnae', 'toolkit', 'toolbox', 'noded', 'clusterd', 'hccl-controller', 'ascend-operator', 'npu-exporter', 'resilience-controller', 'mindie_image', 'mcu'] check_items = ['full', 'fast'] +stdout_callbacks = ["default", "json", "yaml", "minimal", "dense", "oneline", + "community.general.yaml", "community.general.json", "null", + "ansible.builtin.default", "selective", "unixy", "debug"] LOG_MAX_BACKUP_COUNT = 5 LOG_MAX_SIZE = 20 * 1024 * 1024 @@ -304,3 +308,49 @@ def get_hosts_name(tags): if (isinstance(tags, str) and tags in dl_items) or (isinstance(tags, list) and set(tags) & set(dl_items)): return 'master,worker' return 'worker' + + +class Validator: + """ + This class is mainly to validate some value like ip address + + """ + + @staticmethod + def is_valid_ipv4(ip): + """ + return True if the ip is ipv4 else False + :param ip: the string of ip address + :return: bool, true if ipv4 otherwise false + """ + if not isinstance(ip, str): + return False + try: + socket.inet_pton(socket.AF_INET, ip) + return True + except (socket.error, ValueError, AttributeError): + return False + + @staticmethod + def is_valid_ipv6(ip): + """ + return True if the ip is ipv6 else False + :param ip: the string of ip address + :return: bool, true if ipv6 otherwise false + """ + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except (socket.error, ValueError, AttributeError): + return False + + def is_valid_ip(self, ip): + """ + :param ip: the string of ip address + :return: bool: true is validate otherwise false + """ + if not isinstance(ip, str): + return False + if ip.lower() == "localhost": + return True + return self.is_valid_ipv4(ip) or self.is_valid_ipv6(ip) diff --git a/roce-tool/roce_tool/config.py b/roce-tool/roce_tool/config.py index 5ed4e366..03e317fd 100644 --- a/roce-tool/roce_tool/config.py +++ b/roce-tool/roce_tool/config.py @@ -18,6 +18,7 @@ import getpass import logging import logging.handlers import os +import socket ROOT_PATH = os.path.dirname(__file__) WORK_PATH = os.path.expanduser("~/.roce_tool") @@ -26,6 +27,52 @@ LOG_MAX_SIZE = 20 * 1024 * 1024 LOG_FILE = os.path.join(WORK_PATH, "roce_tool.log") +class Validator: + """ + This class is mainly to validate some value like ip address + + """ + + @staticmethod + def is_valid_ipv4(ip): + """ + return True if the ip is ipv4 else False + :param ip: the string of ip address + :return: bool, true if ipv4 otherwise false + """ + if not isinstance(ip, str): + return False + try: + socket.inet_pton(socket.AF_INET, ip) + return True + except (socket.error, ValueError, AttributeError): + return False + + @staticmethod + def is_valid_ipv6(ip): + """ + return True if the ip is ipv6 else False + :param ip: the string of ip address + :return: bool, true if ipv6 otherwise false + """ + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except (socket.error, ValueError, AttributeError): + return False + + def is_valid_ip(self, ip): + """ + :param ip: the string of ip address + :return: bool: true is validate otherwise false + """ + if not isinstance(ip, str): + return False + if ip.lower() == "localhost": + return True + return self.is_valid_ipv4(ip) or self.is_valid_ipv6(ip) + + class UserHostFilter(logging.Filter): user = getpass.getuser() host = os.getenv("SSH_CLIENT", "localhost").split()[0] @@ -33,6 +80,8 @@ class UserHostFilter(logging.Filter): def filter(self, record): record.user = self.user record.host = self.host + if not Validator().is_valid_ip(self.host): + raise RuntimeError("wrong data of env virable SSH_CLIENT") return True diff --git a/test/test_utils.py b/test/test_utils.py index cf82dcb2..64815be6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,8 +1,15 @@ +import os import string +import sys import unittest import errno from unittest.mock import patch + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from ascend_deployer.utils import Validator from ascend_deployer.utils import get_validated_env + PATH_WHITE_LIST_LIN = string.digits + string.ascii_letters + '~-+_./ ' MAX_PATH_LEN = 4096 @@ -70,5 +77,20 @@ class TestGetValidatedEnv(unittest.TestCase): mock_lexists.side_effect = IOError(errno.ENOENT, 'No such file or directory') self.assertEqual(get_validated_env('TEST_ENV', check_symlink=True), '/path/to/file') + +class TestValidator(unittest.TestCase): + + def test_valid_ip(self): + validator = Validator() + self.assertEqual(True, validator.is_valid_ip("192.168.0.1")) + self.assertEqual(True, validator.is_valid_ip("0.0.0.0")) + self.assertEqual(True, validator.is_valid_ip("2001:0db8:85a3::8a2e:0370:7334")) + self.assertEqual(True, validator.is_valid_ip("::1")) + self.assertEqual(False, validator.is_valid_ip(" ")) + self.assertEqual(False, validator.is_valid_ip("11111111111")) + self.assertEqual(True, validator.is_valid_ip("localhost")) + self.assertEqual(True, validator.is_valid_ip("LOCALHOST")) + + if __name__ == '__main__': unittest.main() \ No newline at end of file -- Gitee