diff --git a/test/maple_test/compare.py b/test/maple_test/compare.py index a5f1d92b390b324b2276b43b4fe6be68324b52ee..fa7e40a13d15e73e12971703b7edf7519d070eaa 100644 --- a/test/maple_test/compare.py +++ b/test/maple_test/compare.py @@ -23,7 +23,7 @@ from textwrap import indent from functools import partial from utils import complete_path, read_file -from utils import split_comment, filter_line +from utils import split_comment, filter_line, escape ASSERT_FLAG = "ASSERT" EXPECTED_FLAG = "EXPECTED" @@ -239,7 +239,7 @@ def gen_compare_regex(comment, assert_flags, expected_flag): regex = "" for flag in expected_flag: excepted_regex = r"(?:{comment}\s*)({flag}[\t ]*\:[\t ]*.*$)".format( - comment=comment, flag=flag + comment=escape("\\$()*+.[]?^{}|", comment), flag=flag ) if regex != "": regex = "{}|{}".format(regex, excepted_regex) @@ -247,7 +247,7 @@ def gen_compare_regex(comment, assert_flags, expected_flag): regex = excepted_regex for flag in assert_flags: assert_regex = r"(?:^[\t ]*{comment}\s*)({flag}[\t ]*\:[\t ]*.*$)".format( - comment=comment, flag=flag + comment=escape("\\$()*+.[]?^{}|", comment), flag=flag ) if regex != "": regex = "{}|{}".format(regex, assert_regex) diff --git a/test/maple_test/task.py b/test/maple_test/task.py index 05a6ac8efb4f753a0318a4eaa33eb8a455e8ea4b..45e1fa168e4fda780691aed6db2b2417f8085c3e 100644 --- a/test/maple_test/task.py +++ b/test/maple_test/task.py @@ -43,6 +43,7 @@ from maple_test.utils import ( ls_all, complete_path, is_relative, + quote, ) @@ -544,7 +545,7 @@ class SingleTask: for command in case.commands: command = self._form_line(command, config) compare_cmd = " {} {} --comment={} ".format( - EXECUTABLE, COMPARE, shlex.quote(case.comment) + EXECUTABLE, COMPARE, quote(case.comment) ) self.commands.append(format_compare_command(command, compare_cmd)) diff --git a/test/maple_test/utils.py b/test/maple_test/utils.py index fa27c7f14a173798e03acd62f29a59ff3a5fbc58..f4d7c1b4be5da517002d644161f0d7657449cf4f 100644 --- a/test/maple_test/utils.py +++ b/test/maple_test/utils.py @@ -20,6 +20,9 @@ import locale import os import sys import timeit +import re +import platform +import shlex from functools import wraps from pathlib import Path @@ -198,3 +201,23 @@ def merge_result(multi_results): if result in UNSUCCESSFUL: return result return PASS + + +def escape(special_chars, original_string): + special_re = re.compile( + "(" + "|".join(re.escape(char) for char in list(special_chars)) + ")" + ) + special_map = {char: "\\%s" % char for char in special_chars} + + def escape_special_char(m): + char = m.group(1) + return special_map[char] + + return special_re.sub(escape_special_char, original_string) + + +def quote(original_string): + if platform.system() == "Windows": + return '"' + escape('\\"', original_string) + '"' + else: + return shlex.quote(original_string)