From 2153d96e52f128afa5b457623258d77ffb203162 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 11 May 2024 09:21:32 +0800 Subject: [PATCH 1/3] fix --- .../api_accuracy_checker/run_ut/run_ut.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 3aacb9e17..a3353cc75 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -538,12 +538,12 @@ def run_ut_command(args): class UtDataInfo: - def __init__(self, bench_grad_out, device_grad_out, device_out, - bench_out, grad_in, in_fwd_data_list, backward_message, rank=0): - self.bench_grad_out = bench_grad_out - self.device_grad_out = device_grad_out - self.device_out = device_out - self.bench_out = bench_out + def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, + backward_message, rank=0): + self.bench_grad = bench_grad + self.device_grad = device_grad + self.device_output = device_output + self.bench_output = bench_output self.grad_in = grad_in self.in_fwd_data_list = in_fwd_data_list self.backward_message = backward_message -- Gitee From 05d45fbb64b3e7c7768bc0a2534a9727c257c559 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 11 May 2024 09:39:13 +0800 Subject: [PATCH 2/3] fix --- .../api_accuracy_checker/test/ut/run_ut/test_run_ut.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py index 3412620f7..3c180fa23 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py @@ -62,10 +62,10 @@ class TestRunUtMethods(unittest.TestCase): def test_UtDataInfo(self): data_info = UtDataInfo(None, None, None, None, None, None, None) - self.assertIsNone(data_info.bench_grad_out) - self.assertIsNone(data_info.device_grad_out) - self.assertIsNone(data_info.device_out) - self.assertIsNone(data_info.bench_out) + self.assertIsNone(data_info.bench_grad) + self.assertIsNone(data_info.device_grad) + self.assertIsNone(data_info.device_output) + self.assertIsNone(data_info.bench_output) self.assertIsNone(data_info.grad_in) self.assertIsNone(data_info.in_fwd_data_list) self.assertIsNone(data_info.backward_message) -- Gitee From a2a8af7a64f88319ad234ca96062049f03805e5e Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 11 May 2024 09:43:46 +0800 Subject: [PATCH 3/3] fix --- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 3c9b5e921..30bf805e7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -182,8 +182,8 @@ class Comparator: def compare_output(self, full_api_name, data_info): _, api_name, _ = full_api_name.split("*") - bench_output, device_output = data_info.bench_out, data_info.device_out - bench_grad, device_grad = data_info.bench_grad_out, data_info.device_grad_out + bench_output, device_output = data_info.bench_output, data_info.device_output + bench_grad, device_grad = data_info.bench_grad, data_info.device_grad backward_message = data_info.backward_message compare_func = self._compare_dropout if "dropout" in full_api_name else self._compare_core_wrapper # forward result compare -- Gitee