diff --git a/test/maple_test/compare.py b/test/maple_test/compare.py index c54d9365c8a5c12171ddfabe9917db694ad6bbe8..e6f41e0f57c4fbcf53c1a469f679d09e2964ca9a 100644 --- a/test/maple_test/compare.py +++ b/test/maple_test/compare.py @@ -20,10 +20,20 @@ import logging import re import sys from textwrap import indent +from functools import partial from utils import complete_path, read_file from utils import split_comment, filter_line +ASSERT_FLAG = "ASSERT" + +SCAN_KEYWORDS = ["auto", "not", "next", "end"] +CMP_KEYWORDS = ["end", "not", "next", "full"] + + +class CompareError(Exception): + pass + def main(): opts = parse_cli() @@ -32,8 +42,15 @@ def main(): compare_object = opts.compare_object assert_flags = opts.assert_flag if not assert_flags: - assert_flags.append("ASSERT") + assert_flags.append(ASSERT_FLAG) content = compare_object.read() + lines = content.splitlines(True) + line_map = [] + start = 0 + for line in lines: + length = len(line) + line_map.append((start, start + length, line)) + start += length print("compare.py input:") print(indent(content, "\t", lambda line: True)) @@ -43,7 +60,8 @@ def main(): sys.stderr.write("ERROR: require compare objects, filepath or stdin \n") sys.exit(253) comment_lines = split_comment(comment, read_file(case_path))[1] - result = True + compare_result = True + print("Starting Match:") for assert_flag in assert_flags: assert_lines = [ line for line in comment_lines if filter_line(line, assert_flag) @@ -53,76 +71,92 @@ def main(): "ASSERT flag: {}, No regex find, " "make sure you write the assert line".format(assert_flag) ) - passed = True + match_pass = True + start = 0 for assert_line in assert_lines: - found = False - flag, pattern = extract_pattern(assert_line, assert_flag) - if pattern is None: - logging.error( - "ASSERT flag: {}, Failed Reason: " - "Not found valid match pattern".format(assert_flag) - ) - sys.exit(1) - try: - re.compile(pattern) - except re.error: - passed = False - logging.error( - "ASSERT flag: {}, Failed Reason: " - "Error pattern: {!r}".format(assert_flag, pattern) - ) + + pattern_flag, pattern = extract_pattern(assert_line, assert_flag) + if not is_valid_pattern(pattern): + match_pass = False break - print( - "ASSERT flag: {}, Match regex: {}, " - "is not-scan: {}".format(assert_flag, pattern, not flag) - ) - if pattern: - found = is_match(pattern, content) - if found == flag: - passed = True + + keywords = pattern_flag.split("-") + valid_keywords = [] + assert_mode = keywords[0] + match_func = None + if assert_mode == "scan": + match_func = regex_match + valid_keywords = SCAN_KEYWORDS + elif assert_mode == "cmp": + match_func = cmp_match + valid_keywords = CMP_KEYWORDS else: - passed = False - logging.error( - "Failed Reason: regex: {}, is not-scan: {}, Matched: {}, exit".format( - pattern, not flag, found + raise CompareError("scan mode: {} is not valid".format(assert_mode)) + for keyword in keywords[1:]: + if keyword not in valid_keywords: + raise CompareError( + "keyword: {} is not valid for {}".format(keyword, assert_mode) ) + if keyword == "auto": + match_func = partial(auto_regex_match, match_func=match_func) + elif keyword == "not": + match_func = partial(not_match, match_func=match_func) + elif keyword == "next": + match_func = partial(next_match, match_func=match_func) + elif keyword == "end": + match_func = end_match + elif keyword == "full": + match_func = full_match + if "next" not in keywords and "end" not in keywords: + start = 0 + result, start = match_func(content, line_map, pattern, start) + print( + " assert line: '{}', result: {}".format( + assert_line.split(":")[-1].strip(), result ) - break - if passed is False: + ) + match_pass &= result + compare_result &= match_pass + if match_pass is False: + print("Match End:") + print( + "ASSERT flag: {}, Compare Failed: {}".format( + " ".join(assert_flags), compare_result + ) + ) result = False - if result is True: + if compare_result is True: + print("Match End !!!") print( - "ASSERT flag: {}, Compare Passed: {}".format(" ".join(assert_flags), result) + "ASSERT flag: {}, Compare Passed: {}".format( + " ".join(assert_flags), compare_result + ) ) return 0 + sys.exit(1) def extract_pattern(line, flag): - line_flag = line.strip().split(":")[0].strip() - if line_flag != flag: - return None, None - - line = line.strip()[len(line_flag) + 1 :].strip().lstrip(":").strip() - if line[:5] == "scan ": - words = line[4:].strip().split() - pattern = r"\s*".join([word.strip() for word in words]) - return True, pattern - if line[:9] == "scan-not ": - words = line[8:].strip().split() - pattern = r"\s*".join([word.strip() for word in words]) - return False, pattern - if line[:10] == "scan-auto ": - words = line[9:].strip().split() - pattern = r"\s*".join([re.escape(word.strip()) for word in words]) - return True, pattern - return None, None - - -def is_match(pattern, test_str): - if re.findall(pattern, test_str, re.MULTILINE): + line_flag, pattern_line = line.lstrip().split(":", 1) + if line_flag.strip() != flag: + raise CompareError("Error: {} = {}".format(line_flag, flag)) + try: + pattern_flag, raw_pattern = pattern_line.lstrip().split(" ", 1) + except ValueError: + pattern_flag = pattern_line.lstrip() + raw_pattern = None + return pattern_flag, raw_pattern + + +def is_valid_pattern(pattern): + try: + re.compile(pattern) + except re.error: + logging.error("Error pattern: {!r}".format(pattern)) + return False + finally: return True - return False def parse_cli(): @@ -138,7 +172,7 @@ def parse_cli(): "case_path", type=complete_path, help="Source path: read compare rules" ) parser.add_argument( - "compare_object", + "--compare_object", nargs="?", type=argparse.FileType("r"), default=sys.stdin, @@ -148,6 +182,63 @@ def parse_cli(): return opts +def regex_match(content, line_map, pattern, start=0): + matches = re.finditer(str(pattern), content, re.MULTILINE) + end = 0 + for _, match in enumerate(matches, start=1): + end = match.end() + line_num = text_index_to_line_num(line_map, end) + if line_num + 1 >= len(line_map): + return True, end + return True, line_map[line_num + 1][0] + return False, start + + +def cmp_match(content, line_map, pattern, start=0): + line_num = text_index_to_line_num(line_map, start) + line = content.splitlines()[line_num] + if line == pattern: + return True, line_map[line_num][1] + else: + return False, start + + +def auto_regex_match(content, line_map, pattern, start=0, match_func=regex_match): + pattern = r"\s+".join([re.escape(word.strip()) for word in pattern.strip().split()]) + return match_func(content, line_map, pattern, start) + + +def not_match(content, line_map, pattern, start=0, match_func=regex_match): + result, end = match_func(content, line_map, pattern, start) + return not result, end + + +def next_match(content, line_map, pattern, start=0, match_func=regex_match): + return match_func(content, line_map, pattern, start) + + +def end_match(content, line_map, pattern, start=0, match_func=regex_match): + line_num = text_index_to_line_num(line_map, start) + if line_num < len(line_map): + return False, start + return True, start + + +def full_match(content, line_map, pattern, start=0, match_func=regex_match): + pattern = pattern.encode("utf-8").decode("unicode_escape") + if content != pattern: + return False, start + return True, start + + +def text_index_to_line_num(line_map, index): + for line_num, line in enumerate(line_map): + start, end, _ = line + if start <= index < end: + return line_num + return line_num + 1 + + if __name__ == "__main__": logging.basicConfig( format="\t%(message)s", level=logging.DEBUG, stream=sys.stderr, diff --git a/test/maple_test/task.py b/test/maple_test/task.py index 5ecbc6731de65f16375fb571c9599533700a3ea2..857b4d4814e594992ac056ce29e1ff42088b7403 100644 --- a/test/maple_test/task.py +++ b/test/maple_test/task.py @@ -177,12 +177,13 @@ class TestSuiteTask: return testlist def _search_list(self, base_dir, testlist, encoding): + logger = configs.LOGGER suffixes = self.suffix_comments.keys() include, exclude = testlist case_files = set() cases = [] case_files = self._search_case(include, exclude, base_dir, suffixes) - if self.path.is_file() and self.path in case_files: + if self.path.is_file(): case_files = [self.path] else: case_files = [ @@ -190,6 +191,11 @@ class TestSuiteTask: for file in case_files if is_relative(file, self.path) ] + if not case_files: + logger.info( + "Path %s not in testlist, be sure add path to testlist", + str(self.path), + ) for case_file in case_files: case_name = str(case_file).replace(".", "_") comment = self.suffix_comments[case_file.suffix[1:]]