From c5fd19251140886a7dffd5709577ae3a17ab00e6 Mon Sep 17 00:00:00 2001 From: i-robot Date: Tue, 5 Aug 2025 08:41:53 +0000 Subject: [PATCH] =?UTF-8?q?!4994=20=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91ms?= =?UTF-8?q?=20compare=20api=5Fmapping=E4=B8=ADapi=E5=90=8D=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=E7=9A=84=E4=BF=AE=E6=94=B9=20Merge?= =?UTF-8?q?=20pull=20request=20!4994=20from=20yinglinwei/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/core/compare/acc_compare.py | 22 ++++++++++++------- .../test/core_ut/compare/test_acc_compare.py | 12 +++++----- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index a2a635740c..77d6d6f91e 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -359,13 +359,17 @@ class ProcessDf: return npu_op_name def modify_compare_data_with_user_mapping(self, npu_df, bench_df): + def remove_prefix(string, prefix): + if string.startswith(prefix): + return string[len(prefix):] + return string + def gen_input_compare_key(pattern, term): is_unmatched = True for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): - if op_name.split(pattern)[1].startswith(str(prefix)): + if remove_prefix(op_name, api_origin_name + pattern) == str(prefix): npu_df.loc[index, CompareConst.CMP_KEY] = ( - op_name.replace(pattern + str(prefix), - pattern + str(mapping_dict.get(f'pt_{term}')[i]))) + op_name.replace(pattern + str(prefix), pattern + str(mapping_dict.get(f'pt_{term}')[i]))) is_unmatched = False return is_unmatched @@ -387,15 +391,17 @@ class ProcessDf: continue for index in ms_api_indices_dict.get(ms_api): op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) - if CompareConst.INPUT_PATTERN in op_name: + state = npu_df.loc[index, Const.STATE] + api_origin_name = npu_df.loc[index, Const.API_ORIGIN_NAME].replace(ms_api, pt_api, 1) + if state == Const.INPUT: is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') - elif CompareConst.KWARGS_PATTERN in op_name: + elif state == Const.KWARGS: is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') - elif CompareConst.OUTPUT_PATTERN in op_name: + elif state == Const.OUTPUT: is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') - elif CompareConst.PARAMS_PATTERN in op_name: + elif state == Const.PARAMS: is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') - elif CompareConst.PARAMS_GRAD_PATTERN in op_name: + elif state == Const.PARAMS_GRAD: is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') else: logger.error(f'Excepted op_name: {op_name}') diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index 451eb9badf..a0d6aeee05 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -611,13 +611,13 @@ class TestProcessDf(unittest.TestCase): }] npu_df = pd.DataFrame([ - ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0'], - ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0'] - ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) bench_df = pd.DataFrame([ - ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0'], - ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0'] - ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) process_df.modify_compare_data_with_user_mapping(npu_df, bench_df) -- Gitee