diff --git a/ascend_deployer/large_scale_deploy/tools/str_tool.py b/ascend_deployer/large_scale_deploy/tools/str_tool.py index 35f73b5e86cd24a124a0e77f1dd1d1247b0be5aa..2023fcc3985a62949c1d022c7787f398366a4f02 100644 --- a/ascend_deployer/large_scale_deploy/tools/str_tool.py +++ b/ascend_deployer/large_scale_deploy/tools/str_tool.py @@ -3,8 +3,9 @@ import re class StrTool: - _NON_WORD_PATTERN = re.compile(r"[^a-zA-Z0-9]") + _FURMULA_PATTERN = r'^[\w\s\.\+\-\*\/\(\)\'"]+$' + _EXCEPTION = ["()"] _SAFE_EVAL_SCOPE = { '__builtins__': None, 'int': int, @@ -12,9 +13,14 @@ class StrTool: } @classmethod - def to_py_field(cls, src_field: str): + def to_py_field(cls, src_field): return cls._NON_WORD_PATTERN.sub("_", src_field) @classmethod - def safe_eval(cls, expr: str): - return str(eval(expr, cls._SAFE_EVAL_SCOPE)) \ No newline at end of file + def safe_eval(cls, expr): + if not re.fullmatch(cls._FURMULA_PATTERN, expr): + raise ValueError("unsafe expression: {}".format(expr)) + for k in cls._EXCEPTION: + if k in expr: + raise ValueError("unsafe expression: {}".format(expr)) + return str(eval(expr, cls._SAFE_EVAL_SCOPE)) diff --git a/ascend_deployer/module_utils/inventory_file.py b/ascend_deployer/module_utils/inventory_file.py index 92618f6bb72aa7580bafdd20b33dc91ae23ceadf..54f0083cd92874035ebac4e81a2bd33b086a6c05 100644 --- a/ascend_deployer/module_utils/inventory_file.py +++ b/ascend_deployer/module_utils/inventory_file.py @@ -33,6 +33,8 @@ class Mark: class StrTool: _NON_WORD_PATTERN = re.compile(r"[^a-zA-Z0-9]") + _FURMULA_PATTERN = r'^[\w\s\.\+\-\*\/\(\)\'"]+$' + _EXCEPTION = ["()"] _SAFE_EVAL_SCOPE = { '__builtins__': None, 'int': int, @@ -45,6 +47,11 @@ class StrTool: @classmethod def safe_eval(cls, expr): + if not re.fullmatch(cls._FURMULA_PATTERN, expr): + raise ValueError("unsafe expression: {}".format(expr)) + for k in cls._EXCEPTION: + if k in expr: + raise ValueError("unsafe expression: {}".format(expr)) return str(eval(expr, cls._SAFE_EVAL_SCOPE)) diff --git a/test/module_utils_test/test_inventory_file.py b/test/module_utils_test/test_inventory_file.py index 073880fb8e001fed4c6ccdc4ae4001272a201272..171d583f8ee401a784d10959e33b0162be30ba8f 100644 --- a/test/module_utils_test/test_inventory_file.py +++ b/test/module_utils_test/test_inventory_file.py @@ -1,8 +1,10 @@ +import os +import sys import unittest -from ctypes import oledll from unittest.mock import patch - -from ascend_deployer.module_utils.inventory_file import Var, Host, HostParams, InventoryFile, ConfigrationError, IPRange +from ascend_deployer.module_utils.inventory_file import ( + Var, Host, HostParams, InventoryFile, ConfigrationError, IPRange, StrTool + ) class TestVar(unittest.TestCase): @@ -121,8 +123,8 @@ class TestHostParam(unittest.TestCase): ['10.10.10.1 set_hostname="master-1"', '10.10.10.2 set_hostname="master-2"'] ), ( - 'local_ip_port="{ip}:8080"', - ['10.10.10.1 local_ip_port="10.10.10.1:8080"', '10.10.10.2 local_ip_port="10.10.10.2:8080"'] + 'set_hostname="master-{str(index+1)+\'x\'}"', + ['10.10.10.1 set_hostname="master-2x"', '10.10.10.2 set_hostname="master-3x"'] ), ( 'local_ip_port="{ip}-{index}:8080"', @@ -149,3 +151,31 @@ class TestInventoryFile(unittest.TestCase): def test_get_parsed_inventory_file_path(self): self.inventory_file.get_parsed_inventory_file_path() + + +class TestStrTool(unittest.TestCase): + + @patch("sys.exit") + def test_safe_eval(self, mock_exit): + test_cases = [ + ("'master-'+str(1+20)+'x'", "master-21x"), + ("'master-'+'d'+'d'+str(1+1+1+1)", "master-dd4"), + ("'master.'+str(1+1)", "master.2") + ] + for param, expect in test_cases: + self.assertEqual(expect, StrTool.safe_eval(param)) + + error_cases = [ + "__import__('os')", + "'None'[0]", + "open('/etc')", + "'ddd'.upper()" + "'master-'+{open('/etc/passwd')}", + ] + for case in error_cases: + with self.assertRaises(Exception): + StrTool.safe_eval(case) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 64815be61a239bf8ecc227abb9e689c48ae493cc..c610a309b7738896abf0739daee1a1e373b620d0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,14 +5,13 @@ import unittest import errno from unittest.mock import patch -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - from ascend_deployer.utils import Validator from ascend_deployer.utils import get_validated_env PATH_WHITE_LIST_LIN = string.digits + string.ascii_letters + '~-+_./ ' MAX_PATH_LEN = 4096 + class TestGetValidatedEnv(unittest.TestCase): @patch('os.getenv')