From bbabe4fb327859ca1ecdc5357bc85db00701379d Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Wed, 30 Jul 2025 16:13:33 +0800 Subject: [PATCH] mscompare api_mapping name process improve --- .../msprobe/core/compare/acc_compare.py | 22 ++++++++++++------- .../test/core_ut/compare/test_acc_compare.py | 12 +++++----- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index af9a518ff..53c1257b0 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -356,13 +356,17 @@ class ProcessDf: return npu_op_name def modify_compare_data_with_user_mapping(self, npu_df, bench_df): + def remove_prefix(string, prefix): + if string.startswith(prefix): + return string[len(prefix):] + return string + def gen_input_compare_key(pattern, term): is_unmatched = True for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): - if op_name.split(pattern)[1].startswith(str(prefix)): + if remove_prefix(op_name, api_origin_name + pattern) == str(prefix): npu_df.loc[index, CompareConst.CMP_KEY] = ( - op_name.replace(pattern + str(prefix), - pattern + str(mapping_dict.get(f'pt_{term}')[i]))) + op_name.replace(pattern + str(prefix), pattern + str(mapping_dict.get(f'pt_{term}')[i]))) is_unmatched = False return is_unmatched @@ -384,15 +388,17 @@ class ProcessDf: continue for index in ms_api_indices_dict.get(ms_api): op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) - if CompareConst.INPUT_PATTERN in op_name: + state = npu_df.loc[index, Const.STATE] + api_origin_name = npu_df.loc[index, Const.API_ORIGIN_NAME].replace(ms_api, pt_api, 1) + if state == Const.INPUT: is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') - elif CompareConst.KWARGS_PATTERN in op_name: + elif state == Const.KWARGS: is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') - elif CompareConst.OUTPUT_PATTERN in op_name: + elif state == Const.OUTPUT: is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') - elif CompareConst.PARAMS_PATTERN in op_name: + elif state == Const.PARAMS: is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') - elif CompareConst.PARAMS_GRAD_PATTERN in op_name: + elif state == Const.PARAMS_GRAD: is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') else: logger.error(f'Excepted op_name: {op_name}') diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index 1f7c515a5..cc120183e 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -604,13 +604,13 @@ class TestProcessDf(unittest.TestCase): }] npu_df = pd.DataFrame([ - ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0'], - ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0'] - ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) bench_df = pd.DataFrame([ - ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0'], - ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0'] - ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0', 'input', 'Functional.conv2d.0.forward'], + ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0', 'input', 'Functional.amax.0.forward'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'state', 'api_origin_name']) process_df.modify_compare_data_with_user_mapping(npu_df, bench_df) -- Gitee