From f44fee3ee904e1fff91eb061278ea656a43e2679 Mon Sep 17 00:00:00 2001 From: litian_drinksnow <1063185601@qq.com> Date: Thu, 5 Sep 2024 17:20:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8C=E6=AD=A5msprobe/PT=20master=E5=88=86?= =?UTF-8?q?=E6=94=AF=E5=88=B0poc=E5=88=86=E6=94=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/__init__.py | 3 +- .../api_accuracy_checker/common/config.py | 19 +- .../api_accuracy_checker/common/utils.py | 151 ++++---- .../api_accuracy_checker/compare/algorithm.py | 7 +- .../compare/api_precision_compare.py | 123 +++++-- .../api_accuracy_checker/compare/compare.py | 140 +++++--- .../compare/compare_utils.py | 20 +- .../pytorch/api_accuracy_checker/config.yaml | 7 +- .../run_ut/data_generate.py | 22 +- .../run_ut/multi_run_ut.py | 12 +- .../run_ut/run_overflow_check.py | 13 +- .../api_accuracy_checker/run_ut/run_ut.py | 283 +++++++++------ .../run_ut/run_ut_utils.py | 65 +++- .../run_ut/torch_ut_setting.json | 3 + .../tensor_transport_layer/__init__.py | 0 .../tensor_transport_layer/attl.py | 197 +++++++++++ .../tensor_transport_layer/client.py | 325 ++++++++++++++++++ .../tensor_transport_layer/device_dispatch.py | 204 +++++++++++ .../tensor_transport_layer/server.py | 219 ++++++++++++ .../tensor_transport_layer/ssl_config.py | 10 + .../pytorch/bench_functions/apply_adam_w.py | 2 +- .../bench_functions/confusion_transpose.py | 2 +- .../pytorch/bench_functions/fast_gelu.py | 4 +- .../bench_functions/layer_norm_eval.py | 4 +- .../msprobe/pytorch/bench_functions/linear.py | 2 +- .../bench_functions/npu_fusion_attention.py | 100 +++++- .../pytorch/bench_functions/rms_norm.py | 2 +- .../pytorch/bench_functions/rotary_mul.py | 2 +- .../bench_functions/scaled_mask_softmax.py | 2 +- .../msprobe/pytorch/bench_functions/swiglu.py | 2 +- .../msprobe/pytorch/common/log.py | 11 - .../msprobe/pytorch/common/utils.py | 82 ++++- .../pytorch/compare/distributed_compare.py | 71 +--- .../msprobe/pytorch/compare/match.py | 9 +- .../msprobe/pytorch/compare/pt_compare.py | 51 +++ .../pytorch/debugger/debugger_config.py | 12 +- .../pytorch/debugger/precision_debugger.py | 66 ++-- .../pytorch/free_benchmark/__init__.py | 2 +- .../free_benchmark/compare/grad_saver.py | 48 ++- .../msprobe/pytorch/free_benchmark/main.py | 11 +- .../perturbed_layers/npu/add_noise.py | 2 +- .../perturbed_layers/npu/bit_noise.py | 2 +- .../perturbed_layers/npu/change_value.py | 2 +- .../perturbed_layers/npu/improve_precision.py | 6 +- .../perturbed_layers/npu/no_change.py | 2 +- .../perturbed_layers/run_cpu.py | 2 +- .../result_handlers/handler_factory.py | 1 - .../msprobe/pytorch/function_factory.py | 9 +- .../msprobe/pytorch/functional/dump_module.py | 2 +- .../msprobe/pytorch/grad_probe/__init__.py | 0 .../pytorch/grad_probe/grad_monitor.py | 92 +++++ .../pytorch/grad_probe/grad_stat_csv.py | 129 +++++++ .../pytorch/hook_module/hook_module.py | 11 +- .../pytorch/hook_module/support_wrap_ops.yaml | 4 +- .../msprobe/pytorch/hook_module/utils.py | 19 +- .../msprobe/pytorch/hook_module/wrap_aten.py | 17 +- .../pytorch/hook_module/wrap_distributed.py | 17 +- .../pytorch/hook_module/wrap_functional.py | 11 +- .../pytorch/hook_module/wrap_npu_custom.py | 24 +- .../pytorch/hook_module/wrap_tensor.py | 11 +- .../msprobe/pytorch/hook_module/wrap_torch.py | 12 +- .../msprobe/pytorch/hook_module/wrap_vf.py | 14 +- .../msprobe/pytorch/module_processer.py | 29 +- .../pytorch/online_dispatch/compare.py | 8 +- .../pytorch/online_dispatch/dispatch.py | 92 +++-- .../pytorch/online_dispatch/dump_compare.py | 67 +--- .../pytorch/online_dispatch/single_compare.py | 18 +- .../msprobe/pytorch/online_dispatch/utils.py | 58 +--- .../msprobe/pytorch/parse_tool/lib/compare.py | 43 ++- .../msprobe/pytorch/parse_tool/lib/config.py | 1 + .../pytorch/parse_tool/lib/parse_tool.py | 8 +- .../msprobe/pytorch/parse_tool/lib/utils.py | 123 ++----- .../pytorch/parse_tool/lib/visualization.py | 14 +- .../msprobe/pytorch/pt_config.py | 68 +++- .../accuracy_tools/msprobe/pytorch/service.py | 127 +++++-- 75 files changed, 2456 insertions(+), 897 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/grad_probe/__init__.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index 58ab1ac35..c4e426772 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -1,5 +1,4 @@ from .debugger.precision_debugger import PrecisionDebugger from .common.utils import seed_all -from .compare.acc_compare import compare from .compare.distributed_compare import compare_distributed -from .visualization.graph_service import compare_graph, build_graph +from .compare.pt_compare import compare \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index 760e7c862..14478dfef 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -1,15 +1,14 @@ import os import yaml -from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import check_file_or_directory_path +from msprobe.core.common.utils import load_yaml from msprobe.pytorch.pt_config import RunUTConfig class Config: def __init__(self, yaml_file): check_file_or_directory_path(yaml_file, False) - with FileOpen(yaml_file, 'r') as file: - config = yaml.safe_load(file) + config = load_yaml(yaml_file) self.config = {key: self.validate(key, value) for key, value in config.items()} def __getattr__(self, item): @@ -24,7 +23,13 @@ class Config: 'white_list': list, 'black_list': list, 'error_data_path': str, - 'precision': int + 'precision': int, + 'is_online': bool, + 'nfs_path': str, + 'host': str, + 'port': int, + 'rank_list': list, + 'tls_path': str } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -38,6 +43,10 @@ class Config: RunUTConfig.check_filter_list_config(key, value) if key == 'error_data_path': RunUTConfig.check_error_data_path_config(value) + if key == 'nfs_path': + RunUTConfig.check_nfs_path_config(value) + if key == 'tls_path': + RunUTConfig.check_tls_path_config(value) return value diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index b6e893296..d9e46bbe9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -14,10 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -import json import os import re -import csv +from collections import namedtuple import torch @@ -29,21 +28,19 @@ else: IS_GPU = False from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory -from msprobe.core.common.const import Const, FileCheckConst +from msprobe.pytorch.common.utils import save_pt +from msprobe.core.common.file_check import create_directory +from msprobe.core.common.const import Const from msprobe.core.common.utils import CompareException +ApiData = namedtuple('ApiData', ['name', 'args', 'kwargs', 'result', 'step', 'rank'], + defaults=['unknown', None, None, None, 0, 0]) + class DumpException(CompareException): pass -def write_csv(data, filepath): - with FileOpen(filepath, 'a', encoding='utf-8-sig') as f: - writer = csv.writer(f) - writer.writerows(data) - - def check_object_type(check_object, allow_type): """ Function Description: @@ -59,58 +56,6 @@ def check_object_type(check_object, allow_type): raise CompareException(CompareException.INVALID_DATA_ERROR) -def check_file_or_directory_path(path, isdir=False): - """ - Function Description: - check whether the path is valid - Parameter: - path: the path to check - isdir: the path is dir or file - Exception Description: - when invalid data throw exception - """ - if isdir: - if not os.path.exists(path): - logger.error('The path {} is not exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.path.isdir(path): - logger.error('The path {} is not a directory.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.W_OK): - logger.error( - 'The path {} does not have permission to write. Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - else: - if not os.path.isfile(path): - logger.error('{} is an invalid file or non-exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.R_OK): - logger.error( - 'The path {} does not have permission to read. Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - -def get_json_contents(file_path): - ops = get_file_content_bytes(file_path) - try: - json_obj = json.loads(ops) - except ValueError as error: - logger.error('Failed to load "%s". %s' % (file_path, str(error))) - raise CompareException(CompareException.INVALID_FILE_ERROR) from error - if not isinstance(json_obj, dict): - logger.error('Json file %s, content is not a dictionary!' % file_path) - raise CompareException(CompareException.INVALID_FILE_ERROR) - return json_obj - - -def get_file_content_bytes(file): - with FileOpen(file, 'rb') as file_handle: - return file_handle.read() - - class SoftlinkCheckException(Exception): pass @@ -151,33 +96,19 @@ def cross_entropy_process(api_info_dict): Return api_info_dict: api_info_dict: Processed argument of the API. """ - if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: - if api_info_dict['args'][1]['Min'] <= 0: + if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 and 'Min' in api_info_dict['input_args'][1]: + if api_info_dict['input_args'][1]['Min'] <= 0: # The second argument in cross_entropy should be -100 or not less than 0 - api_info_dict['args'][1]['Min'] = 0 + api_info_dict['input_args'][1]['Min'] = 0 return api_info_dict def initialize_save_path(save_path, dir_name): data_path = os.path.join(save_path, dir_name) - if os.path.exists(data_path): - logger.warning(f"{data_path} already exists, it will be overwritten") - else: - os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - data_path_checker = FileChecker(data_path, FileCheckConst.DIR) - data_path_checker.common_check() + create_directory(data_path) return data_path -def write_pt(file_path, tensor): - if os.path.exists(file_path): - raise ValueError(f"File {file_path} already exists") - torch.save(tensor, file_path) - full_path = os.path.realpath(file_path) - change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY) - return full_path - - def get_real_data_path(file_path): targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+'] pattern = re.compile(r'({})'.format('|'.join(targets))) @@ -211,7 +142,12 @@ class UtDataProcessor: api_args = api_name + Const.SEP + str(self.index) create_directory(self.save_path) file_path = os.path.join(self.save_path, f'{api_args}.pt') - write_pt(file_path, element.contiguous().cpu().detach()) + try: + tensor = element.contiguous().detach().cpu() + except Exception as err: + logger.error(f"Failed to transfer tensor to cpu for {api_args}") + raise DumpException(DumpException.INVALID_DATA_ERROR) from err + save_pt(tensor, file_path) self.index += 1 elif element is None or isinstance(element, (bool, int, float, str, slice)): self.index += 1 @@ -223,3 +159,56 @@ class UtDataProcessor: self._save_recursive(api_name, value) else: self.index += 1 + + +def extract_basic_api_segments(api_full_name): + """ + Function Description: + Extract the name of the API. + Parameter: + api_full_name: Full name of the API. Example: torch.matmul.0, torch.linalg.inv.0 + Return: + api_type: Type of api. Example: torch, tensor, etc. + api_name: Name of api. Example: matmul, linalg.inv, etc. + """ + api_type = None + api_parts = api_full_name.split(Const.SEP) + api_parts_length = len(api_parts) + if api_parts_length == Const.THREE_SEGMENT: + api_type, api_name, _ = api_parts + elif api_parts_length == Const.FOUR_SEGMENT: + api_type, prefix, api_name, _ = api_parts + api_name = Const.SEP.join([prefix, api_name]) + else: + api_name = None + return api_type, api_name + + +def extract_detailed_api_segments(full_api_name_with_direction_status): + """ + Function Description: + Extract the name of the API. + Parameter: + full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0 + Return: + api_name: Name of api. Example: matmul, mul, etc. + full_api_name: Full name of api. Example: torch.matmul.0 + direction_status: Direction status of api. Example: forward, backward, etc. + """ + api_type = None + prefix = None + api_name = None + direction_status = None + api_parts = full_api_name_with_direction_status.split(Const.SEP) + api_parts_length = len(api_parts) + if api_parts_length == Const.SIX_SEGMENT: + api_type, api_name, api_order, direction_status, _, _ = api_parts + full_api_name = Const.SEP.join([api_type, api_name, api_order]) + elif api_parts_length == Const.SEVEN_SEGMENT: + api_type, prefix, api_name, api_order, direction_status, _, _ = api_parts + full_api_name = Const.SEP.join([api_type, prefix, api_name, api_order]) + api_name = Const.SEP.join([prefix, api_name]) + else: + full_api_name = None + return api_name, full_api_name, direction_status + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index 1bb19cc04..4f7fa14d3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -6,9 +6,6 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAM from msprobe.core.common.const import CompareConst -DEFAULT_THRESHOLD = 1 - - #cos def cosine_sim(bench_output, device_output): msg = "" @@ -197,8 +194,8 @@ def check_norm_value(normal_value_mask, rel_err, rtol): def get_ulp_err(bench_output, device_output, dtype): parameters = ULP_PARAMETERS.get(dtype) - min_eb = parameters.get('min_eb', DEFAULT_THRESHOLD)[0] - exponent_num = parameters.get('exponent_num', DEFAULT_THRESHOLD)[0] + min_eb = parameters.get('min_eb')[0] + exponent_num = parameters.get('exponent_num')[0] abs_bench = np.abs(bench_output) eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench))) eb = np.maximum(eb, min_eb) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 73bf7c2b8..d93ad463d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -7,19 +7,20 @@ from collections import namedtuple import torch import pandas as pd -from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv +from msprobe.core.common.utils import write_csv from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ - ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \ + ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \ check_inf_or_nan from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path +from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import CompareException -from msprobe.core.common.const import CompareConst, FileCheckConst +from msprobe.core.common.const import Const, CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency', @@ -289,15 +290,47 @@ def api_precision_compare(config): change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) +def online_api_precision_compare(online_config): + rank = online_config.rank + result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace("_rank*.csv", f"_rank{rank}.csv") + details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace("_rank*.csv", f"_rank{rank}.csv") + detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()] + result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()] + if not os.path.exists(result_csv_path): + write_csv(result_csv_title, result_csv_path) + if not os.path.exists(details_csv_path): + write_csv(detail_csv_title, details_csv_path) + config = CompareConfig("", "", result_csv_path, details_csv_path) + try: + npu_data, gpu_data = online_config.npu_data, online_config.gpu_data + check_csv_columns(npu_data.columns, "npu_csv") + check_csv_columns(gpu_data.columns, "gpu_csv") + analyse_csv(npu_data, gpu_data, config) + except Exception as err: + logger.error(f"Online api precision compare Error: {str(err)}") + change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def analyse_csv(npu_data, gpu_data, config): forward_status, backward_status = [], [] - last_api_name, last_api_dtype = None, None + last_api_name, last_api_dtype, last_api_full_name = None, None, None for _, row_npu in npu_data.iterrows(): message = '' compare_column = ApiPrecisionOutputColumn() full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] - _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".") + api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status) + if not api_full_name: + err_message = f"The API name {full_api_name_with_direction_status} is invalid." + logger.error(err_message) + compare_column.api_name = full_api_name_with_direction_status + compare_column.compare_result = CompareConst.SKIP + compare_column.compare_message = err_message + write_detail_csv(compare_column.to_column_value(), config.details_csv_path) + write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]], + config.result_csv_path) + continue if row_gpu.empty: logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') continue @@ -306,35 +339,25 @@ def analyse_csv(npu_data, gpu_data, config): raise CompareException(CompareException.INVALID_DATA_ERROR, msg) row_gpu = row_gpu.iloc[0] new_status = CompareConst.SPACE - # 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 - if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace(): + try: + new_status = get_api_status(row_npu, row_gpu, api_name, compare_column) + except Exception as err: + logger.error(f"Get api status error: {str(err)}") compare_column.api_name = full_api_name_with_direction_status compare_column.compare_result = CompareConst.SKIP - compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE] - new_status = CompareConst.SKIP - write_detail_csv(compare_column.to_column_value(), config.details_csv_path) - else: - compare_column.api_name = full_api_name_with_direction_status - if api_name in ThousandthStandardApi: - new_status = record_thousandth_threshold_result(compare_column, row_npu) - elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \ - api_name in BinaryStandardApi: - new_status = record_binary_consistency_result(api_name, compare_column, row_npu) - elif api_name in AbsoluteStandardApi: - new_status = record_absolute_threshold_result(compare_column, row_npu) - elif api_name in ULPStandardApi and \ - row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST: - us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu) - new_status = record_ulp_compare_result(compare_column, us) - elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST: - bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu) - new_status = record_benchmark_compare_result(compare_column, bs) + compare_column.compare_message = str(err) write_detail_csv(compare_column.to_column_value(), config.details_csv_path) + write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, str(err)]], + config.result_csv_path) + continue + + write_detail_csv(compare_column.to_column_value(), config.details_csv_path) - if last_api_name is not None and api_name != last_api_name: + if last_api_name is not None and api_full_name != last_api_name: if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: message = unsupported_message - write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path) + write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path) + print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP) forward_status, backward_status = [], [] message = '' else: @@ -342,11 +365,12 @@ def analyse_csv(npu_data, gpu_data, config): backward_result = get_api_checker_result(backward_status) message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path) + print_test_success(last_api_name, forward_result, backward_result) forward_status, backward_status = [], [] message = '' is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST - last_api_name = api_name + last_api_name = api_full_name last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] if not is_supported: @@ -362,12 +386,49 @@ def analyse_csv(npu_data, gpu_data, config): if last_api_name is not None: if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: message = unsupported_message - write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path) + write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path) + print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP) else: forward_result = get_api_checker_result(forward_status) backward_result = get_api_checker_result(backward_status) message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path) + print_test_success(last_api_name, forward_result, backward_result) + + +def get_api_status(row_npu, row_gpu, api_name, compare_column): + full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] + # 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 + if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace(): + compare_column.api_name = full_api_name_with_direction_status + compare_column.compare_result = CompareConst.SKIP + compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE] + new_status = CompareConst.SKIP + else: + compare_column.api_name = full_api_name_with_direction_status + if api_name in thousandth_standard_api: + new_status = record_thousandth_threshold_result(compare_column, row_npu) + elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \ + api_name in binary_standard_api: + new_status = record_binary_consistency_result(api_name, compare_column, row_npu) + elif api_name in absolute_standard_api: + new_status = record_absolute_threshold_result(compare_column, row_npu) + elif api_name in ulp_standard_api and \ + row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST: + us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu) + new_status = record_ulp_compare_result(compare_column, us) + elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST: + bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu) + new_status = record_benchmark_compare_result(compare_column, bs) + return new_status + + +def print_test_success(api_full_name, forward_result, backward_result): + is_fwd_success = (forward_result == CompareConst.PASS) + is_bwd_success = (backward_result == CompareConst.PASS or backward_result == CompareConst.SPACE) + logger.info(f"running api_full_name {api_full_name} compare, " + f"is_fwd_success: {is_fwd_success}, " + f"is_bwd_success: {is_bwd_success}") def check_error_rate(npu_error_rate): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index ee4958828..5408dd24f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -1,27 +1,29 @@ # 进行比对及结果展示 import os from collections import namedtuple -import torch + import numpy as np -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv -from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ - DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \ - ULPStandardApi, ThousandthStandardApi, apis_threshold -from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.core.common.utils import write_csv, get_json_contents, CompareException +import torch +from msprobe.core.common.const import Const, CompareConst from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from msprobe.core.common.const import Const, CompareConst +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ + DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \ + ulp_standard_api, thousandth_standard_api, apis_threshold +from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments +from msprobe.pytorch.common.log import logger ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', 'fwd_compare_alg_results', 'bwd_compare_alg_results', 'rank']) -INDEX_TEST_RESULT__GROUP = 3 +INDEX_TEST_RESULT_GROUP = 3 INDEX_FIRST_GROUP = 0 INDEX_MESSAGE = -1 @@ -33,20 +35,34 @@ class Comparator: COLUMN_BACKWARD_SUCCESS = "Backward Test Success" COLUMN_STACK_INFO = "Traceback callstack info" - def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None): - self.save_path = result_csv_path - self.detail_save_path = details_csv_path - if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path): + def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None): + self.save_path_str = result_csv_path + self.detail_save_path_str = details_csv_path + self.save_path_list = [result_csv_path] + self.detail_save_path_list = [details_csv_path] + + if config and config.online_config.is_online: + self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") + self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv") + self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list] + self.detail_save_path_list = \ + [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list] + + if not is_continue_run_ut: self.write_csv_title() if stack_info_json_path: self.stack_info = get_json_contents(stack_info_json_path) else: self.stack_info = None + @staticmethod + def get_path_from_rank(rank, path_list, path_pattern): + return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank) + @staticmethod def print_pretest_result(): logger.info("Successfully completed run_ut/multi_run_ut.") - + @staticmethod def _compare_dropout(bench_output, device_output): tensor_num = bench_output.numel() @@ -75,7 +91,7 @@ class Comparator: error_rate = float(error_nums / bench_output.size) result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR return error_rate, result, "" - + @staticmethod def _get_absolute_threshold_attribute(api_name, dtype): small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') @@ -83,35 +99,18 @@ class Comparator: rtol = apis_threshold.get(api_name).get(dtype).get('rtol') return small_value_threshold, small_value_atol, rtol - def write_csv_title(self): - summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, - self.COLUMN_BACKWARD_SUCCESS, "Message"]] - if not os.path.exists(self.save_path): - write_csv(summary_test_rows, self.save_path) - if not os.path.exists(self.detail_save_path): - write_csv(DETAIL_TEST_ROWS, self.detail_save_path) - - def write_summary_csv(self, test_result): - test_rows = [] - if self.stack_info: - test_rows[0].append(self.COLUMN_STACK_INFO) - - name = test_result[0] - df_row = list(test_result[:INDEX_TEST_RESULT__GROUP]) - if test_result[1] == "SKIP": - df_row.append(test_result[INDEX_TEST_RESULT__GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE]) - if self.stack_info: - stack_info = "\n".join(self.stack_info[name]) - df_row.append(stack_info) - test_rows.append(df_row) - write_csv(test_rows, self.save_path) - - def write_detail_csv(self, test_result): + @staticmethod + def _get_run_ut_detail(test_result): + """get run_ut detail before write to csv, called by online run_ut""" test_rows = [] + try: + subject_prefix = test_result[0] + fwd_result = test_result[3] + bwd_result = test_result[4] + except IndexError as e: + logger.error("List index out of bounds when writing detail CSV.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e - subject_prefix = test_result[0] - fwd_result = test_result[3] - bwd_result = test_result[4] if isinstance(fwd_result, list): for i, test_subject in enumerate(fwd_result): subject = subject_prefix + ".forward.output." + str(i) @@ -124,15 +123,53 @@ class Comparator: test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision) if isinstance(item, float) else item for item in test_subject] test_rows.append([subject] + list(test_subject)) + return test_rows - write_csv(test_rows, self.detail_save_path) + def write_csv_title(self): + summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, + self.COLUMN_BACKWARD_SUCCESS, "Message"]] + for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list): + if not os.path.exists(save_path): + write_csv(summary_test_rows, save_path) + if not os.path.exists(detail_save_path): + write_csv(DETAIL_TEST_ROWS, detail_save_path) + + def write_summary_csv(self, test_result): + test_rows = [] + try: + name = test_result[0] + df_row = list(test_result[:INDEX_TEST_RESULT_GROUP]) + if test_result[1] == CompareConst.SKIP: + df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE]) + if self.stack_info: + stack_info = "\n".join(self.stack_info[name]) + df_row.append(stack_info) + test_rows.append(df_row) + save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str) + except IndexError as e: + logger.error("List index out of bounds when writing summary CSV.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e + write_csv(test_rows, save_path) + + def write_detail_csv(self, test_result): + test_rows = self._get_run_ut_detail(test_result) + detail_save_path = self.get_path_from_rank(test_result[-1], + self.detail_save_path_list, + self.detail_save_path_str) + write_csv(test_rows, detail_save_path) def record_results(self, args): self.write_summary_csv(args) self.write_detail_csv(args) - def compare_output(self, full_api_name, data_info): - _, api_name, _ = full_api_name.split(Const.SEP) + + def compare_output(self, full_api_name, data_info, is_online=False): + """Get compare result and write to result and detail csv. + is_online: bool, default False. True: called by online api precision compare, only compare without write to csv. + """ + _, api_name = extract_basic_api_segments(full_api_name) + if not api_name: + raise ValueError(f"API name {full_api_name} has not been adapted.") bench_output, device_output = data_info.bench_output, data_info.device_output bench_grad, device_grad = data_info.bench_grad, data_info.device_grad backward_message = data_info.backward_message @@ -160,6 +197,9 @@ class Comparator: fwd_compare_alg_results, bwd_compare_alg_results, data_info.rank) + if is_online: + # get run_ut compare detail + return self._get_run_ut_detail(result_info) self.record_results(result_info) return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \ or bwd_success_status == CompareConst.SPACE @@ -234,7 +274,7 @@ class Comparator: if npu_dtype == torch.bfloat16: bench_output = bench_output.to(torch.float32) device_output = device_output.to(torch.float32) - bench_output = bench_output.numpy() + bench_output = bench_output.cpu().numpy() device_output = device_output.cpu().numpy() if cpu_shape != npu_shape: return CompareConst.ERROR, compare_column, f"The shape of bench{str(cpu_shape)} " \ @@ -261,15 +301,15 @@ class Comparator: abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) - if api_name in ThousandthStandardApi: + if api_name in thousandth_standard_api: thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) compare_column.rel_err_thousandth = thousand_res if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) - if api_name in BinaryStandardApi: + if api_name in binary_standard_api: err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output) compare_column.error_rate = err_rate - elif api_name in AbsoluteStandardApi: + elif api_name in absolute_standard_api: small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( api_name, str(dtype)) rel_err = abs_err / abs_bench_with_eps @@ -279,7 +319,7 @@ class Comparator: dtype, rtol) compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol) - elif api_name in ULPStandardApi: + elif api_name in ulp_standard_api: if bench_output.size == 0: compare_column.max_ulp_error = 0 compare_column.mean_ulp_error = 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 5c7e86ff3..4c2b921e1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -2,13 +2,11 @@ import time import os import math -import numpy as np import torch -import yaml -from msprobe.core.common.utils import CompareException + +from msprobe.core.common.utils import CompareException, load_yaml from msprobe.core.common.const import Const from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") @@ -22,17 +20,15 @@ BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_C cur_path = os.path.dirname(os.path.realpath(__file__)) standard_yaml_path = os.path.join(cur_path, "api_precision_standard.yaml") -with FileOpen(standard_yaml_path, 'r') as f: - Apis = yaml.safe_load(f) - AbsoluteStandardApi = Apis.get('AbsoluteThreshStandard') - BinaryStandardApi = Apis.get('BinaryCompareStandard') - ULPStandardApi = Apis.get('ULPStandard') - ThousandthStandardApi = Apis.get('ThousandthStandard') +apis = load_yaml(standard_yaml_path) +absolute_standard_api = apis.get('AbsoluteThreshStandard') +binary_standard_api = apis.get('BinaryCompareStandard') +ulp_standard_api = apis.get('ULPStandard') +thousandth_standard_api = apis.get('ThousandthStandard') threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml") -with FileOpen(threshold_yaml_path, 'r') as f: - apis_threshold = yaml.safe_load(f) +apis_threshold = load_yaml(threshold_yaml_path) DETAIL_TEST_ROWS = [[ diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 2dac535dc..49f8a726d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -2,4 +2,9 @@ white_list: [] black_list: [] error_data_path: './' precision: 14 - \ No newline at end of file +is_online: False +nfs_path: "" +host: "" +port: -1 +rank_list: [0] +tls_path: "" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index b2eec691a..7812ad93b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -25,6 +25,8 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, CompareException from msprobe.core.common.file_check import FileChecker from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import load_pt +from msprobe.core.common.utils import load_npy from msprobe.core.common.const import Const, FileCheckConst TORCH_TYPE = ["torch.device", "torch.dtype"] @@ -76,6 +78,8 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): data = info.get('value') if info.get("type") == "slice": data = slice(*data) + if info.get("type") == "ellipsis": + data = ... return data @@ -94,9 +98,9 @@ def gen_real_tensor(data_path, convert_type): error_info = f"The file: {data_path} is not a pt or numpy file." raise CompareException(CompareException.INVALID_FILE_ERROR, error_info) if data_path.endswith('.pt'): - data = torch.load(data_path, map_location=torch.device('cpu')) + data = load_pt(data_path, to_cpu=True) else: - data_np = numpy.load(data_path) + data_np = load_npy(data_path) data = torch.from_numpy(data_np) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] @@ -257,12 +261,13 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p return args_result -def gen_kwargs(api_info, convert_type=None, real_data_path=None): +def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None): """ Function Description: Based on API basic information, generate input parameters: kwargs, for API forward running Parameter: api_info: API basic information. Dict + api_name: API name convert_type: convert ori_type to dist_type flag. real_data_path: the root directory for storing real data. """ @@ -270,11 +275,11 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None): kwargs_params = api_info.get("input_kwargs") for key, value in kwargs_params.items(): if isinstance(value, (list, tuple)): - kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path) + kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path) elif value is None: kwargs_params[key] = None elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"): - kwargs_params[key] = gen_data(value, True, convert_type, real_data_path) + kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path) elif value.get('type') in TORCH_TYPE: gen_torch_kwargs(kwargs_params, key, value) else: @@ -287,18 +292,19 @@ def gen_torch_kwargs(kwargs_params, key, value): kwargs_params[key] = eval(value.get('value')) -def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None): +def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None): """ Function Description: When kwargs value is list, generate the list of kwargs result Parameter: kwargs_item_value: kwargs value before to generate. List + api_name: API name convert_type: convert ori_type to dist_type flag. """ kwargs_item_result = [] for item in kwargs_item_value: if item.get('type') in TENSOR_DATA_LIST: - item_value = gen_data(item, False, convert_type, real_data_path) + item_value = gen_data(item, api_name, False, convert_type, real_data_path) elif item.get('type') == "torch.Size": item_value = torch.Size(item.get('value')) else: @@ -321,7 +327,7 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d if convert_type and convert_type not in Const.CONVERT: error_info = f"convert_type params not support {convert_type}." raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - kwargs_params = gen_kwargs(api_info, convert_type, real_data_path) + kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) if api_info.get("input_args"): args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 9acb5ee64..8334abef8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -9,13 +9,15 @@ import threading from collections import namedtuple from itertools import cycle from tqdm import tqdm -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \ - get_validated_details_csv_path, preprocess_forward_content +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, preprocess_forward_content +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path, \ + get_validated_details_csv_path from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator from msprobe.pytorch.common import parse_json_info_forward_backward +from msprobe.pytorch.common.log import logger from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory -from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import remove_path from msprobe.core.common.const import FileCheckConst @@ -117,7 +119,7 @@ def run_parallel_ut(config): for api_info in config.api_files: cmd = create_cmd(api_info, next(device_id_cycle)) - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False) processes.append(process) threading.Thread(target=read_process_output, args=(process,), daemon=True).start() @@ -135,7 +137,7 @@ def run_parallel_ut(config): for file in config.api_files: check_link(file) try: - os.remove(file) + remove_path(file) except FileNotFoundError: logger.warning(f"File not found and could not be deleted: {file}") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 732745ee8..1575a77df 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -10,8 +10,9 @@ else: is_gpu = False import torch from tqdm import tqdm -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info -from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api +from msprobe.core.common.utils import get_json_contents from msprobe.core.common.file_check import check_link from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward @@ -81,8 +82,12 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path): npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name) if kwargs.get("device"): del kwargs["device"] - out = exec_api(api_type, api_name, args, kwargs) - npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) + out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs) + npu_out = exec_api(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs) + if out is None and npu_out is None: + logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name) + return + cpu_overflow = check_data_overflow(out) npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out) if cpu_overflow == npu_overflow: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 559dfdc0f..e54982b12 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -1,7 +1,6 @@ import argparse import os import csv -import re import sys import time import gc @@ -18,32 +17,35 @@ else: import torch from tqdm import tqdm -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \ + get_validated_result_csv_path, get_validated_details_csv_path, exec_api from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \ - initialize_save_path, UtDataProcessor +from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \ + initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn -from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate -from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate -from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate -from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.core.common.file_check import FileOpen, FileChecker, \ - change_mode, check_file_suffix, check_link, check_path_before_create, create_directory + change_mode, check_path_before_create, create_directory from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import get_json_contents from msprobe.pytorch.pt_config import parse_json_config from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher + current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', - 'black_list', 'error_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', + 'black_list', 'error_data_path', 'online_config']) + +OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path']) + not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -70,25 +72,6 @@ tqdm_params = { } -def exec_api(api_type, api_name, args, kwargs): - if api_type == "Functional": - functional_api = FunctionalOPTemplate(api_name, str, False) - out = functional_api.forward(*args, **kwargs) - if api_type == "Tensor": - tensor_api = TensorOPTemplate(api_name, str, False) - out = tensor_api.forward(*args, **kwargs) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - out = torch_api.forward(*args, **kwargs) - if api_type == "Aten": - torch_api = AtenOPTemplate(api_name, None, False) - out = torch_api.forward(*args, **kwargs) - if api_type == "NPU": - torch_api = NpuOPTemplate(api_name, None, False) - out = torch_api.forward(*args, **kwargs) - return out - - def deal_detach(arg, to_detach=True): return arg.detach() if to_detach else arg @@ -140,7 +123,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: arg_in = deal_detach(raise_bench_data_dtype( - api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_() + api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_() temp_arg_in = arg_in * 1 arg_in = temp_arg_in.type_as(arg_in) arg_in.retain_grad() @@ -183,25 +166,50 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): def run_ut(config): logger.info("start UT test") - logger.info(f"UT task result will be saved in {config.result_csv_path}") - logger.info(f"UT task details will be saved in {config.details_csv_path}") + if config.online_config.is_online: + logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv")) + logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv")) + else: + logger.info(f"UT task result will be saved in {config.result_csv_path}") + logger.info(f"UT task details will be saved in {config.details_csv_path}") + if config.save_error_data: logger.info(f"UT task error_datas will be saved in {config.error_data_path}") - compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) - with FileOpen(config.result_csv_path, 'r') as file: - csv_reader = csv.reader(file) - next(csv_reader) - api_name_set = {row[0] for row in csv_reader} + compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config) + + if config.online_config.is_online: + run_api_online(config, compare) + else: + with FileOpen(config.result_csv_path, 'r') as file: + csv_reader = csv.reader(file) + next(csv_reader) + api_name_set = {row[0] for row in csv_reader} + run_api_offline(config, compare, api_name_set) + for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list): + change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + logger.info(f"UT task result csv is saved in {result_csv_path}") + logger.info(f"UT task details csv is saved in {details_csv_path}") + compare.print_pretest_result() + + +def run_api_offline(config, compare, api_name_set): + err_column = CompareColumn() for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue - if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api + if is_unsupported_api(api_full_name): + continue + _, api_name = extract_basic_api_segments(api_full_name) + if not api_name: + err_message = f"API {api_full_name} not support for run ut. SKIP." + logger.error(err_message) + fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message) + result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0) + compare.record_results(result_info) continue - [_, api_name, _] = api_full_name.split(Const.SEP) try: - if config.black_list and api_name in config.black_list: - continue - if config.white_list and api_name not in config.white_list: + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) @@ -213,7 +221,6 @@ def run_ut(config): f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") else: logger.error(f"Run {api_full_name} UT Error: %s" % str(err)) - err_column = CompareColumn() fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err)) result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0) compare.record_results(result_info) @@ -223,14 +230,78 @@ def run_ut(config): else: torch.npu.empty_cache() gc.collect() - change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY) - compare.print_pretest_result() + + +def run_api_online(config, compare): + attl = init_attl(config.online_config) + dispatcher = ConsumerDispatcher(compare=compare) + dispatcher.start(handle_func=run_torch_api_online, config=config) + + def tcp_communication_flow(): + while True: + api_data = attl.recv() + if api_data == 'STOP_': + continue + if api_data == 'KILL_': + time.sleep(1) + logger.info("==========接收到STOP信号==========") + dispatcher.stop() + attl.stop_serve() + time.sleep(1) + break + if not isinstance(api_data, ApiData): + continue + api_full_name = api_data.name + _, api_name = extract_basic_api_segments(api_full_name) + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): + continue + if api_data.rank in config.online_config.rank_list: + dispatcher.update_consume_queue(api_data) + + def shared_storage_communication_flow(): + flag_num = -1 + while True: + api_data = attl.download() + if api_data == "start": + if flag_num == -1: + flag_num += 1 + flag_num += 1 + if api_data == "end": + flag_num -= 1 + if flag_num == 0: + dispatcher.stop() + break + if not isinstance(api_data, ApiData): + continue + api_full_name = api_data.name + _, api_name = extract_basic_api_segments(api_full_name) + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): + continue + if api_data.rank in config.online_config.rank_list: + dispatcher.update_consume_queue(api_data) + + if config.online_config.nfs_path: + shared_storage_communication_flow() + else: + tcp_communication_flow() + + +def blacklist_and_whitelist_filter(api_name, black_list, white_list): + """ + run api(api_name) if api_name not in black_list and in white_list. + If api is both in black_list and black_list, black_list first. + return: False for exec api, True for not exec + """ + if black_list and api_name in black_list: + return True + if white_list and api_name not in white_list: + return True + return False def is_unsupported_api(api_name): split_name = api_name.split(Const.SEP)[0] - flag = split_name in [Const.NPU, Const.DISTRIBUTED] + flag = split_name == Const.DISTRIBUTED if flag: logger.info(f"{split_name} api is not supported for run ut. SKIP.") return flag @@ -251,7 +322,7 @@ def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): in_fwd_data_list = [] backward_message = '' - [api_type, api_name, _] = api_full_name.split(Const.SEP) + api_type, api_name = extract_basic_api_segments(api_full_name) args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) @@ -269,8 +340,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name) device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name) bench_grad_out, device_grad_out = None, None - out = exec_api(api_type, api_name, cpu_args, cpu_kwargs) - device_out = exec_api(api_type, api_name, device_args, device_kwargs) + out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs) + device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs) current_path = os.path.dirname(os.path.realpath(__file__)) ut_setting_path = os.path.join(current_path, "torch_ut_setting.json") api_setting_dict = get_json_contents(ut_setting_path) @@ -290,10 +361,27 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict device_grad_out = run_backward(device_args, device_grad, grad_index, device_out) else: backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE + if api_name == "npu_fusion_attention": + out = out[0] + device_out = device_out[0] return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message) +def run_torch_api_online(api_full_name, api_data, backward_content): + in_fwd_data_list = [] + api_type, api_name = extract_basic_api_segments(api_full_name) + args, kwargs, out = api_data.args, api_data.kwargs, api_data.result + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) + if kwargs.get("device"): + del kwargs["device"] + + device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs) + device_out = move2device_exec(device_out, "cpu") + return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) + + def get_api_info(api_info_dict, api_name, real_data_path): convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True @@ -333,35 +421,21 @@ def initialize_save_error_data(error_data_path): return error_data_path -def get_validated_result_csv_path(result_csv_path, mode): - if mode not in ['result', 'detail']: - raise ValueError("The csv mode must be result or detail") - result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE, - file_type=FileCheckConst.CSV_SUFFIX) - validated_result_csv_path = result_csv_path_checker.common_check() - if mode == 'result': - result_csv_name = os.path.basename(validated_result_csv_path) - pattern = r"^accuracy_checking_result_\d{14}\.csv$" - if not re.match(pattern, result_csv_name): - raise ValueError("When continue run ut, please do not modify the result csv name.") - return validated_result_csv_path - - -def get_validated_details_csv_path(validated_result_csv_path): - result_csv_name = os.path.basename(validated_result_csv_path) - details_csv_name = result_csv_name.replace('result', 'details') - details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name) - details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE, - ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX) - validated_details_csv_path = details_csv_path_checker.common_check() - return validated_details_csv_path +def init_attl(config): + """config: OnlineConfig""" + attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True, + connect_ip=config.host, + connect_port=config.port, + nfs_path=config.nfs_path, + tls_path=config.tls_path)) + return attl def _run_ut_parser(parser): parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str, - help=" The api param tool result file: generate from api param tool, " + help=" The api param tool result file: generate from api param tool, " "a json file.", - required=True) + required=False) parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) @@ -451,20 +525,26 @@ def run_ut_command(args): except Exception as error: logger.error(f"Set device id failed. device id is: {args.device_id}") raise NotImplementedError from error - check_link(args.api_info_file) - api_info = os.path.realpath(args.api_info_file) - check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX) + + # 在线预检场景下,不需要外出输出api信息,forward_content, backward_content, real_data_path设置为None + # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析 + forward_content, backward_content, real_data_path = None, None, None + if args.api_info_file: + api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE, + ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX) + checked_api_info = api_info_file_checker.common_check() + forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info) + if args.filter_api: + logger.info("Start filtering the api in the forward_input_file.") + forward_content = preprocess_forward_content(forward_content) + logger.info("Finish filtering the api in the forward_input_file.") + out_path = os.path.realpath(args.out_path) if args.out_path else "./" check_path_before_create(out_path) create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data - forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info) - if args.filter_api: - logger.info("Start filtering the api in the forward_input_file.") - forward_content = preprocess_forward_content(forward_content) - logger.info("Finish filtering the api in the forward_input_file.") result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) @@ -474,35 +554,40 @@ def run_ut_command(args): white_list = msCheckerConfig.white_list black_list = msCheckerConfig.black_list error_data_path = msCheckerConfig.error_data_path + is_online = msCheckerConfig.is_online + nfs_path = msCheckerConfig.nfs_path + host = msCheckerConfig.host + port = msCheckerConfig.port + rank_list = msCheckerConfig.rank_list + tls_path = msCheckerConfig.tls_path if args.config_path: - _, task_config = parse_json_config(args.config_path, Const.RUN_UT) + config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE, + FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) + checked_config_path = config_path_checker.common_check() + _, task_config = parse_json_config(checked_config_path, Const.RUN_UT) white_list = task_config.white_list black_list = task_config.black_list error_data_path = task_config.error_data_path + is_online = task_config.is_online + nfs_path = task_config.nfs_path + host = task_config.host + port = task_config.port + rank_list = task_config.rank_list + tls_path = task_config.tls_path + if save_error_data: if args.result_csv_path: time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info error_data_path = initialize_save_error_data(error_data_path) + online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path) + args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path, + online_config) run_ut(run_ut_config) -class UtDataInfo: - def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, - backward_message, rank=0): - self.bench_grad = bench_grad - self.device_grad = device_grad - self.device_output = device_output - self.bench_output = bench_output - self.grad_in = grad_in - self.in_fwd_data_list = in_fwd_data_list - self.backward_message = backward_message - self.rank = rank - - if __name__ == '__main__': _run_ut() logger.info("UT task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index d78642f21..2ded677d7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -1,7 +1,70 @@ +import os +import re + +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_check import FileChecker +from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate +from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate +from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate +from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate +from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate + hf_32_standard_api = ["conv1d", "conv2d"] class Backward_Message: MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported." UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward." - NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward." \ No newline at end of file + NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward." + + +class UtDataInfo: + def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, + backward_message, rank=0): + self.bench_grad = bench_grad + self.device_grad = device_grad + self.device_output = device_output + self.bench_output = bench_output + self.grad_in = grad_in + self.in_fwd_data_list = in_fwd_data_list + self.backward_message = backward_message + self.rank = rank + + +def get_validated_result_csv_path(result_csv_path, mode): + if mode not in ['result', 'detail']: + raise ValueError("The csv mode must be result or detail") + result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE, + file_type=FileCheckConst.CSV_SUFFIX) + validated_result_csv_path = result_csv_path_checker.common_check() + if mode == 'result': + result_csv_name = os.path.basename(validated_result_csv_path) + pattern = r"^accuracy_checking_result_\d{14}\.csv$" + if not re.match(pattern, result_csv_name): + raise ValueError("When continue run ut, please do not modify the result csv name.") + return validated_result_csv_path + + +def get_validated_details_csv_path(validated_result_csv_path): + result_csv_name = os.path.basename(validated_result_csv_path) + details_csv_name = result_csv_name.replace('result', 'details') + details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name) + details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE, + ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX) + validated_details_csv_path = details_csv_path_checker.common_check() + return validated_details_csv_path + + +def exec_api(api_type, api_name, device, args, kwargs): + if api_type == "Functional": + torch_api = FunctionalOPTemplate(api_name, str, False) + if api_type == "Tensor": + torch_api = TensorOPTemplate(api_name, str, False) + if api_type == "Torch": + torch_api = TorchOPTemplate(api_name, str, False) + if api_type == "Aten": + torch_api = AtenOPTemplate(api_name, None, False) + if api_type == "NPU": + torch_api = NpuOPTemplate(api_name, None, False, device) + out = torch_api.forward(*args, **kwargs) + return out diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json index d8df6098b..1eb8b192a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json @@ -1,5 +1,8 @@ { "topk": { "grad_index": 0 + }, + "npu_fusion_attention": { + "grad_index": 0 } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py new file mode 100644 index 000000000..e74eebb19 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py @@ -0,0 +1,197 @@ +import glob +import os.path +import time +import re +from multiprocessing import Queue +from typing import Optional, Union, Dict, Any +from dataclasses import dataclass + +import torch + +from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer +from msprobe.pytorch.common.utils import logger +from msprobe.core.common.utils import remove_path +from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt + +BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] + + +@dataclass +class ATTLConfig: + is_benchmark_device: bool + connect_ip: str + connect_port: int + # storage_config + nfs_path: str = None + tls_path: str = None + check_sum: bool = True + queue_size: int = 50 + + +class ATTL: + def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None: + self.session_id = session_id + self.session_config = session_config + self.logger = logger + self.socket_manager = None + self.data_queue = Queue(maxsize=50) + self.dequeue_list = [] + self.message_end = False + self.kill_progress = False + self.check_attl_config() + if self.session_config.nfs_path: + self.nfs_path = self.session_config.nfs_path + elif self.session_config.is_benchmark_device: + + self.socket_manager = TCPServer(self.session_config.connect_port, + self.data_queue, + self.session_config.check_sum, + self.session_config.tls_path) + self.socket_manager.start() + elif need_dump: + self.socket_manager = TCPClient(self.session_config.connect_ip, + self.session_config.connect_port, + self.session_config.check_sum, + self.session_config.tls_path) + self.socket_manager.start() + + def check_attl_config(self): + if self.session_config.nfs_path: + if os.path.exists(self.session_config.nfs_path): + return + else: + raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.") + ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" + if not re.match(ipv4_pattern, self.session_config.connect_ip): + raise Exception(f"host {self.session_config.connect_ip} is invalid.") + if not (0 < self.session_config.connect_port <= 65535): + raise Exception(f"port {self.session_config.connect_port} is invalid.") + + def stop_serve(self): + if isinstance(self.socket_manager, TCPServer): + self.socket_manager.stop() + + def send(self, buffer: BufferType) -> None: + """ + npu major in 'send' (client) + """ + # know receiver receive and go next + if isinstance(buffer, ApiData): + buffer = move2target_device(buffer, torch.device('cpu')) + + if 'device' in buffer.kwargs: + buffer.kwargs.pop('device') + rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0 + step = buffer.step if hasattr(buffer, "step") else 0 + try: + io_buff = save_api_data(buffer) + except Exception as e: + self.logger.info(f"{buffer.name} can not be saved, skip: {e}") + return + data = io_buff.getvalue() + self.socket_manager.add_to_sending_queue(data, rank=rank, step=step) + + def recv(self, timeout_ms=0) -> Optional[BufferType]: + buffer = None + while buffer is None: + if timeout_ms > 0: + time.sleep(timeout_ms / 1000.0) + if buffer is None and not self.data_queue.empty(): + buffer = self.data_queue.get() + break + if buffer is None and timeout_ms > 0: # timeout is the only case we give up and return None + break + if self.message_end and self.data_queue.empty(): + buffer = b"KILL_CONFIRM" + self.kill_progress = True + break + time.sleep(0.1) # waiting outside the lock before next attempt + if buffer is None: + # this is a result of a timeout + self.logger.info(f"RECEIVE API DATA TIMED OUT") + else: + if buffer == b"STOP_": + return "STOP_" + if buffer == b"KILL_": + self.message_end = True + return "STOP_" + if buffer == b"KILL_CONFIRM": + self.kill_progress = True + return "KILL_" + try: + buffer = load_api_data(buffer) + except Exception as e: + self.logger.warning("there is something error. please check it. %s", e) + if isinstance(buffer, bytes): + return None + if isinstance(buffer, str): + return buffer + + return buffer + + def upload(self, buffer: BufferType): + if isinstance(buffer, ApiData): + buffer = move2target_device(buffer, torch.device('cpu')) + file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt") + else: + file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}") + + try: + save_pt(buffer, file_path) + except Exception as e: + self.logger.warning("there is something error in save_pt. please check it. %s", e) + + def download(self): + buffer = None + cur_file = None + for file_type in ("start*", "*.pt", "end*"): + pattern = os.path.join(self.nfs_path, file_type) + files = glob.glob(pattern) + if len(files) > 0: + cur_file = files[0] + break + + if cur_file is not None: + try: + buffer = load_pt(cur_file) + except Exception as e: + self.logger.warning("there is something error. please check it. %s", e) + remove_path(cur_file) + return buffer + + +def move2device_exec(obj, device): + if isinstance(obj, (tuple, list)): + data_list = [move2device_exec(val, device) for val in obj] + return data_list if isinstance(obj, list) else tuple(data_list) + if isinstance(obj, dict): + return {key: move2device_exec(val, device) for key, val in obj.items()} + elif isinstance(obj, torch.Tensor): + obj = obj.detach() + if obj.device.type != device: + obj = obj.to(device) + return obj + elif "return_types" in str(type(obj)): + return move2device_exec(tuple(obj), device) + elif isinstance(obj, torch._C.device): + return torch.device(device) + else: + return obj + + +def move2target_device(buffer: ApiData, target_device): + # handle args + new_args = move2device_exec(buffer.args, target_device) + + # handle kwargs + new_kwargs = move2device_exec(buffer.kwargs, target_device) + + # handle result + new_results = move2device_exec(buffer.result, target_device) + + if target_device == torch.device('cpu') or target_device == "cpu": + return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank) + else: + return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py new file mode 100644 index 000000000..9d5090c5e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py @@ -0,0 +1,325 @@ +import hashlib +import io +import struct +import time +import os +import signal +import sys +from queue import Queue +from threading import Thread +from typing import Union + +from twisted.internet import reactor, protocol, endpoints +from twisted.protocols.basic import FileSender + +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list + + +class TCPDataItem: + def __init__(self, data, + sequence_number: int, + rank: int = 0, + step: int = 0): + self.raw_data = data + self.sequence_number = sequence_number + self.rank = rank + self.step = step + self.retry_times = 0 + self.pending_time = 0 + self.busy_time = 0 + + +class TCPClient: + MAX_SENDING_QUEUE_SIZE = 20 + ACK_SUCCESS = b"OK___" + ACK_ERROR = b"ERROR" + ACK_BUSY = b"BUSY_" + ACK_STOP = b"STOP_" + ACK_STOP_CONFIRM = b"OVER_" + ACK_KILL_PROCESS = b"KILL_" + + QUEUE_PENDING_TIME = 600 # 队列10分钟都处于阻塞状态,则终止sending进程 + RESEND_RETRY_TIMES = 2 # 最大重传数 + RESEND_TIMER_TIME = 5 # 接收ACK超时定时器 + RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据 + + def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None): + self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE) + self.resend_dict = dict() + self.host = host + self.port = port + self.tls_path = tls_path + self.factory = None + self.sequence_number = 0 + self.signal_exit = False + self.tcp_manager = ClientProtocol(ack_queue_size=100, + chunk_size=655360, + check_sum=check_sum) + self.send_thread = Thread(target=self._sending_queue_data) + self.send_thread.setDaemon(True) + self.send_thread.start() + self.destroy_thread = Thread(target=self._destroy_queue_data) + self.destroy_thread.setDaemon(True) + self.destroy_thread.start() + + @staticmethod + def run_reactor(): + reactor.run(installSignalHandlers=False) + + def start(self): + def conn_callback(cur_protocol): + if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host: + logger.debug(f"Process: {os.getpid()} connects to server successfully.") + else: + logger.warning(f"Process: {os.getpid()} fails to connect to server. ") + raise ConnectionError(f"Failed to connect to {self.host}.") + + def conn_err_callback(failure): + self.signal_exit = True + time.sleep(1) + reactor.stop() + logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}") + os.kill(os.getpid(), signal.SIGKILL) + os.kill(os.getppid(), signal.SIGKILL) + + def cur_protocol(): + return self.tcp_manager + + self.factory = MessageClientFactory() + self.factory.protocol = cur_protocol + if self.tls_path: + from OpenSSL import SSL + from twisted.internet import ssl + client_key = os.path.join(self.tls_path, "client.key") + client_crt = os.path.join(self.tls_path, "client.crt") + client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD) + client_context_ = client_context_factory.getContext() + client_context_.set_cipher_list(cipher_list) + client_context_.set_options(SSL.OP_NO_RENEGOTIATION) + endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory) + else: + endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port) + d = endpoint.connect(self.factory) + d.addCallback(conn_callback) + d.addErrback(conn_err_callback) + + reactor_thread = Thread(target=self.run_reactor, daemon=True) + reactor_thread.start() + + def send_after_queue_empty(self, data): + while not self._ready_to_exit(): + self.add_to_sending_queue(data) + time.sleep(2) + + def check_client_alive(self): + return self.factory.num_connections > 0 + + def stop(self): + self.tcp_manager.connection_timeout() + + def send_stop_signal(self): + self.send_after_queue_empty(self.ACK_STOP) + while not self._ready_to_exit(): + if not self.check_client_alive(): + break + time.sleep(1) + while not self.tcp_manager.kill_process: + time.sleep(1) + + def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0): + if self._ready_to_exit(): + return + + send_data = data + if not isinstance(data, TCPDataItem): + send_data = TCPDataItem(data=data, + sequence_number=self.sequence_number, + rank=rank, + step=step) + self.sequence_number += 1 + try: + self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME) + except Exception as e: + logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step}," + f"sequence_number: {send_data.sequence_number}, {str(e)}") + + def _send_data(self, data: TCPDataItem): + self.tcp_manager.send_wrapped_data(data.raw_data, + sequence_number=data.sequence_number, + rank=data.rank, + step=data.step + ) + + def _sending_queue_data(self): + while True: + if not self.tcp_manager.is_connected: + continue + + while self.send_queue.qsize() > 0: + if self._ready_to_exit(): + break + if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE: + data_obj = self.send_queue.get() + self._send_data(data_obj) + resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step) + if resend_key not in self.resend_dict.keys(): + # Send data for the first time + self.resend_dict[resend_key] = data_obj + else: + time.sleep(0.1) + + if self._ready_to_exit(): + logger.debug("Successfully close sending process.") + break + time.sleep(0.1) + + def _destroy_queue_data(self): + while True: + if self._ready_to_exit(): + break + + while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0: + ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get() + obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step) + current_item = self.resend_dict.get(obj_key) + + if current_item is None: + continue + + if ack_info == self.ACK_SUCCESS: + self.resend_dict.pop(obj_key) + elif ack_info == self.ACK_BUSY: + logger.debug("RECV BUSY ACK") + if current_item.busy_time > 5: + self._resend_data(current_item) + else: + current_item.busy_time += 1 + elif ack_info == self.ACK_ERROR: + logger.debug("RECV ERROR ACK") + self._resend_data(current_item) + elif ack_info == self.ACK_STOP_CONFIRM: + logger.debug("RECV STOP ACK") + self.factory.num_connections -= 1 + + break + + time.sleep(0.1) + + def _resend_data(self, data: TCPDataItem): + if data.retry_times < self.RESEND_RETRY_TIMES: + data.retry_times += 1 + logger.debug(f"Resend data seq number: {data.sequence_number}") + self.add_to_sending_queue(data) + else: + self.resend_dict.pop(data.sequence_number) + logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!") + + def _pending_data(self, data: TCPDataItem): + if data.pending_time >= self.RESEND_PENDING_TIME: + self.resend_dict.pop(data.sequence_number) + logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!") + return + + # wait time is 100MB per second + pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50)) + data.pending_time += pending_time + time.sleep(pending_time) + + def _ready_to_exit(self): + return self.signal_exit or self.tcp_manager.signal_exit + + +class ClientProtocol(protocol.Protocol): + TIMEOUT = 60 * 10 + + def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False): + self.buffer = io.BytesIO() + self.is_connected = False + self.check_sum = check_sum + self.tell = 0 + self.ack_queue = Queue(maxsize=ack_queue_size) + self.file_sender = FileSender() + self.file_sender.CHUNK_SIZE = chunk_size + self.signal_exit = False + self.defer = None + self.kill_process = False + + def dataReceived(self, data): + if self.timeout_call.active(): + self.timeout_call.reset(self.TIMEOUT) + + self.buffer.seek(0, 2) + self.buffer.write(data) + self.buffer.seek(self.tell) + while True: + if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3 + ack = self.buffer.read(5) + seq_number = struct.unpack('!Q', self.buffer.read(8))[0] + rank = struct.unpack('!Q', self.buffer.read(8))[0] + step = struct.unpack('!Q', self.buffer.read(8))[0] + if ack == b"KILL_": + self.kill_process = True + logger.debug(f"接收到KILL信号, PID {os.getpid()}") + if ack == b"OVER_": + self.factory.num_connections -= 1 + self.tell += 29 + if not self.ack_queue.full(): + self.ack_queue.put((ack, seq_number, rank, step)) + self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) + self.tell = 0 + else: + time.sleep(0.1) + else: + break + + def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0): + length = len(data) + md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else "" + while True: + if self.defer is None or self.defer.called: + self.defer = self.send_large_data( + length.to_bytes(8, byteorder='big') + + sequence_number.to_bytes(8, byteorder='big') + + rank.to_bytes(8, byteorder='big') + + step.to_bytes(8, byteorder='big') + + md5_hash.encode() + + data) + break + time.sleep(0.01) + + def send_large_data(self, data): + d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport) + return d + + def connection_timeout(self): + if self.factory.num_connections <= 0: + return + + self.factory.num_connections -= 1 + logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}") + self.transport.loseConnection() + + def connectionMade(self): + self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout) + self.is_connected = True + self.factory.num_connections += 1 + logger.info("successfully connect server") + + def connectionLost(self, reason): + self.signal_exit = True + self.factory.num_connections -= 1 + logger.info(f"Lost connection with server, reason is : {reason}") + + +class MessageClientFactory(protocol.ClientFactory): + def __init__(self): + self.num_connections = 0 + + def clientConnectionFailed(self, connector, reason): + logger.info(f"Fail to connection with server: {reason.getErrorMessage()}") + reactor.stop() + + def clientConnectionLost(self, connector, reason): + logger.info(f"Client lost connection with server: {reason.getErrorMessage()}") + reactor.stop() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py new file mode 100644 index 000000000..0865a8817 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py @@ -0,0 +1,204 @@ +import time +from collections import namedtuple + +import pandas as pd +import torch +import torch.multiprocessing as mp + +from msprobe.core.common.const import Const, CompareConst +from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \ + binary_standard_api, absolute_standard_api +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device + +# NPU vs GPU api list +CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api) + +current_time = time.strftime("%Y%m%d%H%M%S") +ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + "_rank*.csv" +ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + "_rank*.csv" + +OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig', + ['npu_data', 'gpu_data', 'rank', 'result_csv_path', 'details_csv_path']) +# namedtuple of [instance of Comparator, func of run_touch_api_online, config of run_ut_config] +CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config']) + + +def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file): + """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue. + :param xpu_id: int + :param consumer_queue: shared queues of ConsumerDispatcher + :param common_config: namedtuple of CommonCompareConfig + :param api_precision_csv_file: list, length is 2, result file name and details file name + :return: + """ + gpu_device = torch.device(f'cuda:{xpu_id}') + + while True: + if consumer_queue.empty(): + time.sleep(0.1) + continue + + api_data = consumer_queue.get() + if api_data == "KILL_": + # current consumer finish + return + + _, api_name, _ = api_data.name.split(Const.SEP) + if api_name in CompareApi: + # NPU vs GPU + online_compare(api_data, gpu_device, common_config) + else: + # NPUvsCPU vs GPUvsCPU + online_precision_compare(api_data, gpu_device, common_config, api_precision_csv_file) + + +def online_precision_compare(api_data, device, common_config, api_precision_csv_file): + """online run_ut for precision_compare: NPUvsCPU vs GPUvsCPU + 1. get NPUvsCPU compare result + 2. get GPUvsCPU compare result + 3. call online_api_precision_compare + :param api_data + :param device + :param common_config: namedtuple of CommonCompareConfig + :param api_precision_csv_file: [result_file_name, details_file_name] + """ + compare, func, config = common_config.compare, common_config.handle_func, common_config.config + api_full_name = api_data.name + [api_type, api_name, _] = api_full_name.split(Const.SEP) + npu_args, npu_kwargs, npu_out = api_data.args, api_data.kwargs, api_data.result + + if npu_kwargs.get("device"): + del npu_kwargs["device"] + + try: + # NPU vs CPU + cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, npu_args, npu_kwargs) + npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank) + npu_detail = compare.compare_output(api_full_name, npu_data_info, True) + npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1]) + + # GPU vs CPU + api_data_gpu = move2target_device(api_data, device) # args, kwargs -> gpu, result -> npu + data_info = func(api_full_name, api_data_gpu, config.backward_content) + gpu_out = data_info.bench_output + gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank) + gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True) + gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1]) + + # NPUvsCPU vs GPUvsCPU + result_file_name, details_file_name = api_precision_csv_file + precision_compare_config = OnlineApiPrecisionCompareConfig(npu_data, gpu_data, api_data.rank, + result_file_name, details_file_name) + online_api_precision_compare(precision_compare_config) + + except Exception as err: + if "expected scalar type Long" in str(err): + logger.warning( + f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") + elif api_type in [Const.DISTRIBUTED]: + logger.info(f"{api_full_name} is not supported for run ut. SKIP.") + else: + logger.error(f"Run {api_full_name} UT Error: {str(err)}") + + compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank)) + + finally: + torch.cuda.empty_cache() + + +def online_compare(api_data, device, common_config): + """online run_ut for compare:NPU vs GPU + """ + compare, func, config = common_config.compare, common_config.handle_func, common_config.config + api_full_name = api_data.name + api_data = move2target_device(api_data, device) + try: + data_info = func(api_full_name, api_data, config.backward_content) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) + logger.info(f"running api_full_name {api_full_name} ut, " + f"is_fwd_success: {is_fwd_success}, " + f"is_bwd_success: {is_bwd_success}") + except Exception as err: + [api_type, api_name, _] = api_full_name.split(Const.SEP) + if "expected scalar type Long" in str(err): + logger.warning( + f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") + elif api_type in [Const.DISTRIBUTED]: + logger.info(f"{api_full_name} is not supported for run ut. SKIP.") + else: + logger.error(f"Run {api_full_name} UT Error: {str(err)}") + + compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank)) + + finally: + torch.cuda.empty_cache() + + +class ConsumerDispatcher: + def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None: + self.num_workers = num_workers + self.capacity = capacity + self.compare = compare + self.queues = [] + self.processes = [] + self.reverse_sort = False + self.pool = None + self.device = device + self.data_id = 0 + self.lock = mp.Lock() + self.result_queue = mp.Queue() + mp.set_start_method("spawn", force=True) + + def start(self, handle_func, config): + self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)] + api_precision_csv_file = [ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME, ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME] + common_config = CommonCompareConfig(self.compare, handle_func, config) + for xpu_id, q in enumerate(self.queues): + p = mp.Process(name="run_ut_process", target=run_ut_process, + args=(xpu_id, q, common_config, api_precision_csv_file)) + + p.start() + self.processes.append(p) + logger.info(f"Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}") + logger.info(f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}") + logger.info("Successfully start unittest process.") + + def stop(self): + for q in self.queues: + while q.full(): + time.sleep(0.1) + q.put("KILL_") + + for p in self.processes: + p.join() + logger.info("Successfully stop unittest process.") + logger.info(f"Api_precision_compare task result is saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}") + logger.info(f"Api_precision_compare task details is saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}") + + def update_consume_queue(self, api_data): + while True: + index = self._choose_max_empty_site_strategy() + if index != -1: + q = self.queues[index] + q.put(api_data) + break + time.sleep(0.1) + + def _choose_max_empty_site_strategy(self): + maximum = 0 + index = -1 + # 充分利用多卡资源,防止任务过多分配给前面的卡 + _reverse = 1 if not self.reverse_sort else -1 + for i, q in enumerate(self.queues[::_reverse]): + empty_site = self.capacity - q.qsize() + if empty_site > maximum: + maximum = empty_site + index = i + index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index + self.reverse_sort = not self.reverse_sort + return index diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py new file mode 100644 index 000000000..93040d07f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py @@ -0,0 +1,219 @@ +import os.path +import struct +import hashlib +import time +import io +from threading import Thread + +from twisted.internet import reactor, protocol, endpoints + +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list + + +class TCPServer: + def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None: + self.port = port + self.shared_queue = shared_queue + self.check_sum = check_sum + self.tls_path = tls_path + self.factory = MessageServerFactory() + self.reactor_thread = None + + @staticmethod + def run_reactor(): + reactor.run(installSignalHandlers=False) + + def start(self): + self.factory.protocol = self.build_protocol + + if self.tls_path: + from OpenSSL import SSL + from twisted.internet import ssl + server_key = os.path.join(self.tls_path, "server.key") + server_crt = os.path.join(self.tls_path, "server.crt") + server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD) + server_context_ = server_context_factory.getContext() + server_context_.set_cipher_list(cipher_list) + server_context_.set_options(SSL.OP_NO_RENEGOTIATION) + endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory) + else: + endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port) + endpoint.listen(self.factory) + self.reactor_thread = Thread(target=self.run_reactor, daemon=True) + self.reactor_thread.start() + + def is_running(self): + return not self.factory.is_all_connection_closed() + + def stop(self): + self.factory.doStop() + reactor.callFromThread(reactor.sigInt, 2) + self.reactor_thread.join() + + def build_protocol(self): + return ServerProtocol(self.shared_queue, self.check_sum) + + +class ServerProtocol(protocol.Protocol): + ACK_SUCCESS = b"OK___" + ACK_ERROR = b"ERROR" + ACK_BUSY = b"BUSY_" + ACK_STOP = b"STOP_" + ACK_STOP_CONFIRM = b"OVER_" + ACK_KILL_PROCESS = b"KILL_" + + def __init__(self, shared_queue, check_sum=False): + self.start_time = None + self.buffer = io.BytesIO() + self.consumer_queue = shared_queue + self.check_sum = check_sum + self.length_width = 8 + self.md5_width = 32 + self.obj_length = None + self.tell = 0 + self.obj_md5 = None + self.obj_body = None + self.sequence_number = -1 + self.rank = -1 + self.step = -1 + self.sequence_number_dict = dict() + + def connectionMade(self): + self.buffer = io.BytesIO() + self.obj_length = None + self.tell = 0 + self.obj_md5 = None + self.obj_body = None + self.factory.transport_dict[self.transport] = 1 + self.factory.transport_list.append(self.transport) + logger.info(f"Connected to {self.transport.getPeer()} successfully.") + + def connectionLost(self, reason): + self.factory.transport_dict.pop(self.transport, None) + if len(self.factory.transport_dict) == 0: + self.consumer_queue.put(self.ACK_KILL_PROCESS) + + logger.info(f"Lost connection with {self.transport.getPeer()}. Reason is: {reason} 与客户端 断开连接, " + f"current connection number is: {len(self.factory.transport_dict)}") + + def send_ack(self, ack_info): + ack_message = b"".join([ + ack_info, + self.sequence_number.to_bytes(8, byteorder='big'), + self.rank.to_bytes(8, byteorder='big'), + self.step.to_bytes(8, byteorder='big') + ]) + self.transport.write(ack_message) + + def post_process(self): + send_busy_ack = False + while self.consumer_queue.full(): + if not send_busy_ack: + self.send_ack(self.ACK_BUSY) + logger.debug("sending BUSY ACK") + send_busy_ack = True + time.sleep(0.1) + + obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step) + + recv_md5 = hashlib.md5(self.obj_body).hexdigest() + if self.check_sum and recv_md5 != self.obj_md5: + # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client. + logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}") + self.send_ack(self.ACK_ERROR) + else: + if self.obj_body == self.ACK_STOP: + self.handle_with_stop() + else: + self.send_ack(self.ACK_SUCCESS) + if obj_key in self.sequence_number_dict: + logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}") + else: + self.sequence_number_dict[obj_key] = self.obj_md5 + self.consumer_queue.put(self.obj_body, block=True) + + self.reset_env() + finish_time = time.time() + logger.debug(f"finish_time: {finish_time - self.start_time}") + + def handle_with_stop(self): + logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}") + self.send_ack(self.ACK_STOP_CONFIRM) + if len(self.factory.transport_dict) == 0: + _rank, _step, _sequence_number = 0, 0, 100000000 + ack_kill = self.ACK_KILL_PROCESS + \ + _sequence_number.to_bytes(8, byteorder='big') + \ + _rank.to_bytes(8, byteorder='big') + \ + _step.to_bytes(8, byteorder='big') + for trans in self.factory.transport_list: + trans.write(ack_kill) + logger.debug(f"发送KILL信息给{self.transport.getPeer()}") + self.consumer_queue.put(self.ACK_KILL_PROCESS) + time.sleep(2) + + def reset_env(self): + self.obj_length = None + self.sequence_number = -1 + self.rank = -1 + self.step = -1 + self.obj_md5 = None + self.obj_body = None + + def dataReceived(self, data): + self.buffer.seek(0, 2) + self.buffer.write(data) + self.buffer.seek(self.tell) + + # The first data packet is packet header, it contains obj_length, sequence_number, rank, step + if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4: + self.start_time = time.time() + self.obj_length = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.sequence_number = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.rank = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.step = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.tell += self.length_width * 4 + logger.debug( + f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}") + + # If needs check md5 but not parse md5 yet, read 32b md5 values + check_sum_and_md5 = (self.check_sum + and self.obj_length is not None + and self.obj_md5 is None + and len(self.buffer.getvalue()) - self.tell >= self.md5_width) + if check_sum_and_md5: + self.obj_md5 = self.buffer.read(self.md5_width).decode() + self.tell += self.md5_width + logger.debug(f"MD5: {self.obj_md5}") + + current_length = len(self.buffer.getvalue()) - self.tell + if self.obj_length is not None and 0 < self.obj_length <= current_length: + # Current api data receive finished + self.obj_body = self.buffer.read(self.obj_length) + + self.tell += self.obj_length + self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) + self.buffer.seek(0) + self.tell = 0 + recv_data_time = time.time() + logger.debug(f"self.sequence_number {self.sequence_number} " + f"recv_data_time {recv_data_time - self.start_time}") + + if self.obj_body == self.ACK_STOP: + # Indicates the current TCP link receives a STOP signal and remove from the transport_dict + _transport = self.factory.transport_dict.pop(self.transport, None) + logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ") + self.post_process() + + +class MessageServerFactory(protocol.ServerFactory): + def __init__(self) -> None: + """ + transport_dict: links that have not completed data transmission. + transport_list: Records all TCP links. Appends TCP link to the transport list when a new TCP link is established. + """ + self.transport_dict = {} + self.transport_list = [] + + def is_all_connection_closed(self): + return len(self.transport_dict) == 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py new file mode 100644 index 000000000..b6e815e63 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py @@ -0,0 +1,10 @@ +cipher_list = ":".join([ + 'ECDHE-ECDSA-AES128-GCM-SHA256', + 'ECDHE-RSA-AES128-GCM-SHA256', + 'ECDHE-ECDSA-AES256-GCM-SHA384', + 'ECDHE-RSA-AES256-GCM-SHA384', + 'ECDHE-ECDSA-CHACHA20-POLY1305', + 'ECDHE-RSA-CHACHA20-POLY1305', + 'DHE-RSA-AES128-GCM-SHA256', + 'DHE-RSA-AES256-GCM-SHA384' +]).encode() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py index caf21a604..ce0a999b1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py @@ -25,4 +25,4 @@ def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay, if (1 - beta1_power_out) == 0: beta1_power_out -= eps var_out = var_t + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom) - return var_out.cpu(), m_out.cpu(), v_out.cpu() + return var_out, m_out, v_out diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py index 627bf11b6..bbc45deef 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py @@ -3,7 +3,7 @@ def npu_confusion_transpose(data, perm, shape, transpose_first): output = data.permute(*perm).contiguous().view(shape) else: output = data.view(shape).permute(*perm) - return output.cpu() + return output def npu_confusion_transpose_backward(grad, perm, shape, transpose_first): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py index a1a9ca080..d86f371a4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py @@ -1,7 +1,7 @@ import torch -def fast_gelu(input0): +def npu_fast_gelu(input0): attr = 1.702 const_0 = 0 - attr const_1 = 1 @@ -19,7 +19,7 @@ def fast_gelu(input0): div_down_rec = torch.reciprocal(div_down) result = div_up * div_down_rec - return result.cpu() + return result def npu_fast_gelu_backward(grad, input_x): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py index f6949c079..517c3121e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py @@ -1,6 +1,6 @@ import torch -def npu_layer_norm_eval(data, normalized_shape): +def npu_layer_norm_eval(data, normalized_shape, weight=None, bias=None, eps=1e-5): result = torch.nn.functional.layer_norm(data, normalized_shape) - return result.cpu() + return result diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py index 95db875ed..2caa25893 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py @@ -3,7 +3,7 @@ import torch def npu_linear(x, weight, bias): output = torch.nn.functional.linear(x, weight, bias) - return output.cpu() + return output def npu_linear_backward(grad, input_data, weight): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py index 63f1fa2a3..530a12a93 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py @@ -1,8 +1,22 @@ import torch import numpy as np from einops import rearrange +try: + import torch_npu +except ImportError: + is_gpu = True + try: + # flash_attn为gpu的fa三方库 + from flash_attn import flash_attn_func + except ImportError: + #如果为cpu的ut环境,则不做任何处理 + pass +else: + is_gpu = False + from msprobe.pytorch.common.utils import logger +from msprobe.core.common.const import Const, CompareConst gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86 softmax_build_mode = "QKV" # "MAX_SUM" @@ -43,8 +57,8 @@ def softmax_grad(dp, softmax_res): def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype): - if num_kv_heads == 0 or num_kv_heads < num_heads: - raise ValueError(f"num_kv_heads must be non-zero and less than num_heads.") + if num_kv_heads == 0 or num_kv_heads > num_heads: + raise ValueError(f"num_kv_heads must be non-zero and bigger than num_heads.") factor = num_heads // num_kv_heads kv_shape = kv_tensor.shape @@ -266,12 +280,32 @@ def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softma return softmax_res +def get_head_num(*args, **kwargs): + if kwargs.get("head_num", None): + head_num = kwargs.get("head_num") + elif len(args) >= 4: + head_num = args[3] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return head_num + + +def get_input_layout(*args, **kwargs): + if kwargs.get("input_layout", None): + input_layout = kwargs.get("input_layout") + elif len(args) >= 5: + input_layout = args[4] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return input_layout + + def npu_fusion_attention_forward_patch(*args, **kwargs): # query, key, value, head_num, input_layout - if len(args) != 5: - raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + head_num = get_head_num(*args, **kwargs) + input_layout = get_input_layout(*args, **kwargs) - B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4]) + B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], head_num, input_layout) if N1 == N2 and S1 == S2: logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}") else: @@ -332,7 +366,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs): def npu_fusion_attention(*args, **kwargs): new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs) - query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4] + query, key, value = new_args[0], new_args[1], new_args[2] + input_layout = get_input_layout(*args, **kwargs) N1 = dims_kwargs.get("N1") N2 = dims_kwargs.get("N2") S1 = dims_kwargs.get("S1") @@ -419,3 +454,56 @@ def npu_fusion_attention_grad(*args, **kwargs): dv = convert_from_bnsd(dv, input_layout) return dq.cpu(), dk.cpu(), dv.cpu() + + +def is_attention_off_due_to_mask(atten_mask_dtype): + return not atten_mask_dtype + + +def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1): + return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < S1) + + +def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2): + return sparse_mode == 0 and pre_tockens >= S1 and next_tockens >= S2 + + +def gpu_fusion_attention(*args, **kwargs): + deterministic = False + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs) + query, key, value = new_args[0], new_args[1], new_args[2] + keep_prob = new_kwargs.get("keep_prob", 1.0) + scale = new_kwargs.get("scale") + N1 = dims_kwargs.get("N1") + N2 = dims_kwargs.get("N2") + S1 = dims_kwargs.get("S1") + S2 = dims_kwargs.get("S2") + B = dims_kwargs.get("B") + pse = new_kwargs.get("pse") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + attn_mask = new_kwargs.get("atten_mask") + atten_mask_dtype = attn_mask.dtype if new_kwargs.get("atten_mask") is not None else None + pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens) + next_tockens = min(CompareConst.MAX_TOKENS, next_tockens) + atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or + is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1) or + is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2)) + causal_switch = not atten_off + if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED: + window_left = pre_tockens + window_right = next_tockens + else: + pre_tockens = next_tockens = CompareConst.MAX_TOKENS + window_left = pre_tockens - S1 + S2 + window_right = next_tockens + S1 - S2 + + if pse is not None: + alibi_slopes = torch.rand(B, N1, dtype=torch.float32) * 0.3 + else: + alibi_slopes = None + + out = flash_attn_func(query, key, value, dropout_p=(1-keep_prob), softmax_scale=scale, causal=causal_switch, + window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic) + return out, Const.NONE, Const.NONE diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py index e647312fd..617c96020 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py @@ -4,7 +4,7 @@ import torch def npu_rms_norm(x, gamma, epsilon=1e-5): rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon) res = x * rstd * gamma - return res.cpu(), rstd.float().cpu() + return res, rstd.float() def npu_rms_norm_backward(grad, x, gamma, rstd): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py index 0e0fda5f7..9eb910559 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py @@ -5,7 +5,7 @@ def npu_rotary_mul(x, r1, r2): x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new - return output.cpu() + return output def npu_rotary_mul_backward(dy_tensor, x, r1, r2): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py index 8717aebaf..8fc92d0e8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py @@ -9,7 +9,7 @@ def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask): x = x - torch.max(x, dim=-1, keepdims=True)[0] x = torch.exp(x.float()) y = torch.div(x, torch.sum(x, dim=-1, keepdims=True)) - return y.to(dtype).cpu() + return y.to(dtype) def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py index e03c975a5..e7ffd4179 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py @@ -14,7 +14,7 @@ def npu_swiglu(x, dim=-1): tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type( torch.float32) * tensor_other_float output_data = tensor_out_float.type(tensor_dtype) - return output_data.cpu() + return output_data def npu_swiglu_backward(grad, x, dim=-1): diff --git a/debug/accuracy_tools/msprobe/pytorch/common/log.py b/debug/accuracy_tools/msprobe/pytorch/common/log.py index cea518fa4..e7df8fdf7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/log.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/log.py @@ -17,16 +17,5 @@ class PyTorchLogger(BaseLogger): current_rank = None return current_rank - def _print_log(self, level, msg, end='\n'): - current_rank = self.get_rank() - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - pid = os.getpid() - if current_rank is not None: - full_msg = f"{current_time} ({pid}) [rank {current_rank}] [{level}] {msg}" - else: - full_msg = f"{current_time} ({pid}) [{level}] {msg}" - print(full_msg, end=end) - sys.stdout.flush() - logger = PyTorchLogger() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 181491488..d42a19d96 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -14,15 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +import io import logging import os import random import stat +import csv +import json import torch import torch.distributed as dist import numpy as np from functools import wraps from msprobe.core.common.exceptions import DistributedNotInitializedError +from msprobe.core.common.log import logger +from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, CompareException +from msprobe.core.common.file_check import FileCheckConst, change_mode, FileOpen + try: import torch_npu @@ -31,13 +38,9 @@ except ImportError: else: is_gpu = False -torch_without_guard_version_list = ['2.1', '2.2'] -for version in torch_without_guard_version_list: - if torch.__version__.startswith(version): - torch_without_guard_version = True - break - else: - torch_without_guard_version = False + +torch_without_guard_version = torch.__version__ >= '2.1' + if not is_gpu and not torch_without_guard_version: from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard @@ -245,14 +248,61 @@ def get_tensor_rank(in_feat, out_feat): return tensor_rank -def _create_logger(level=logging.INFO): - logger_ = logging.getLogger() - logger_.setLevel(level) - ch = logging.StreamHandler() - ch.setLevel(level) - logger_.addHandler(ch) - return logger_ +def get_rank_id(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 -log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO -logger = _create_logger(log_level) +def print_rank_0(message): + if dist.is_initialized(): + if dist.get_rank() == 0: + logger.info(message) + else: + logger.info(message) + + +def load_pt(pt_path, to_cpu=False): + pt_path = os.path.realpath(pt_path) + check_file_or_directory_path(pt_path) + try: + if to_cpu: + pt = torch.load(pt_path, map_location=torch.device("cpu")) + else: + pt = torch.load(pt_path) + except Exception as e: + raise RuntimeError(f"load pt file {pt_path} failed") from e + return pt + + +def save_pt(tensor, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + torch.save(tensor, filepath) + except Exception as e: + logger.error("Save pt file failed, please check according possible error causes: " + "1. out of disk space or disk error, " + "2. no permission to write files, etc.") + raise RuntimeError(f"save pt file {filepath} failed") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def save_api_data(api_data): + """Save data to io stream""" + try: + io_buff = io.BytesIO() + torch.save(api_data, io_buff) + except Exception as e: + raise RuntimeError(f"save api_data to io_buff failed") from e + return io_buff + + +def load_api_data(api_data_bytes): + """Load data from bytes stream""" + try: + buffer = io.BytesIO(api_data_bytes) + buffer = torch.load(buffer, map_location="cpu") + except Exception as e: + raise RuntimeError(f"load api_data from bytes failed") from e + return buffer diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index caac13958..bdabd0ee7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -15,63 +15,17 @@ # limitations under the License. """ import os -import sys -import re from msprobe.core.common.utils import CompareException, check_compare_param, \ - check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid -from msprobe.pytorch.compare.acc_compare import compare_core + check_configuration_param, task_dumppath_get from msprobe.core.common.file_check import create_directory from msprobe.core.common.exceptions import FileCheckException from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import Const +from msprobe.pytorch.compare.pt_compare import PTComparator +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): - def check_and_return_dir_contents(dump_dir, prefix): - """ - check the given dump dir and validate files in dump dir by using the given prefix patterns to build a - pattern: ^{prefix}(?:0|[0-9][1-9]*)?$ - - Args: - dump_dir (str): dump dir - prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only - - Returns: - content [list]: dir contents - Raises: - CompareException: invalid path - ValueError: prefix not match the patterns - - """ - check_regex_prefix_format_valid(prefix) - check_file_or_directory_path(dump_dir, True) - contents = os.listdir(dump_dir) - pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$') - for name in contents: - if not pattern.match(name): - logger.error( - f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump " - f"output. Please check and delete irrelevant files in {dump_dir} and try again." - ) - raise CompareException(CompareException.INVALID_PATH_ERROR) - return contents - - def extract_json(dirname, stack_json=False): - json_path = '' - for fname in os.listdir(dirname): - full_path = os.path.join(dirname, fname) - if full_path.endswith('.json'): - json_path = full_path - if not stack_json and 'stack' not in json_path: - break - if stack_json and 'stack' in json_path: - break - - # Provide robustness on invalid directory inputs - if not json_path: - logger.error(f'No file is found in dump dir {dirname}. ') - raise CompareException(CompareException.NO_DUMP_FILE_ERROR) - return json_path - if kwargs.get('suffix'): logger.error("Argument 'suffix' is not supported for compare_distributed.") raise CompareException(CompareException.INVALID_PARAM_ERROR) @@ -89,14 +43,14 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): for nr, br in zip(npu_ranks, bench_ranks): npu_data_dir = os.path.join(npu_dump_dir, nr) bench_data_dir = os.path.join(bench_dump_dir, br) - npu_json_path = extract_json(npu_data_dir, stack_json=False) - bench_json_path = extract_json(bench_data_dir, stack_json=False) - stack_json_path = extract_json(npu_data_dir, stack_json=True) + npu_path = extract_json(npu_data_dir, stack_json=False) + bench_path = extract_json(bench_data_dir, stack_json=False) + stack_path = extract_json(npu_data_dir, stack_json=True) dump_result_param = { - 'npu_json_path': npu_json_path, - 'bench_json_path': bench_json_path, - 'stack_json_path': stack_json_path, + 'npu_json_path': npu_path, + 'bench_json_path': bench_path, + 'stack_json_path': stack_path, 'is_print_compare_log': True } try: @@ -106,6 +60,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare) except (CompareException, FileCheckException) as error: logger.error('Compare failed. Please check the arguments and do it again!') - sys.exit(error.code) - compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, + raise CompareException(error.code) from error + pt_comparator = PTComparator() + pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, md5_compare=md5_compare, **kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/pytorch/compare/match.py index 6347d8887..ac445ad8e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/match.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/match.py @@ -1,16 +1,13 @@ import os -import yaml -from msprobe.core.common.file_check import FileOpen -from msprobe.core.common.utils import CompareException +from msprobe.core.common.utils import CompareException, load_yaml class AtenIrMapping(): def __init__(self): cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "mapping.yaml") - with FileOpen(yaml_path, 'r') as f: - self.aten_mapping = yaml.safe_load(f) - + self.aten_mapping = load_yaml(yaml_path) + def match(self, op1, op2): if "Aten" in op1 and "Aten" not in op2: return self.match_op(op1, op2) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py new file mode 100644 index 000000000..c82f23de4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -0,0 +1,51 @@ +import os.path +import torch +from msprobe.core.common.const import FileCheckConst, Const +from msprobe.core.common.file_check import create_directory +from msprobe.pytorch.common.log import logger +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.acc_compare import Comparator +from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, FileChecker +from msprobe.core.common.utils import CompareException +from msprobe.pytorch.common.utils import load_pt + + +class PTComparator (Comparator): + def __init__(self): + self.frame_name = PTComparator.__name__ + + def read_npy_data(self, dir_path, file_name): + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + try: + data_value = load_pt(data_path, + to_cpu=True).detach() # detach because numpy can not process gradient information + except RuntimeError as e: + # 这里捕获 load_pt 中抛出的异常 + logger.error(f"Failed to load the .pt file at {data_path}.") + raise CompareException(CompareException.INVALID_FILE_ERROR) from e + except AttributeError as e: + # 这里捕获 detach 方法抛出的异常 + logger.error(f"Failed to detach the loaded tensor.") + raise CompareException(CompareException.DETACH_ERROR) from e + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value + + +def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False): + try: + summary_compare, md5_compare = task_dumppath_get(input_param) + check_configuration_param(stack_mode, auto_analyze, fuzzy_match) + create_directory(output_path) + check_compare_param(input_param, output_path, summary_compare, md5_compare) + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. Please check the arguments and do it again!') + raise CompareException(error.code) from error + pt_comparator = PTComparator() + pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, + auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, + md5_compare=md5_compare) diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 851a61d04..7c32be7cc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -13,7 +13,6 @@ class DebuggerConfig: self.seed = common_config.seed if common_config.seed else 1234 self.is_deterministic = common_config.is_deterministic self.enable_dataloader = common_config.enable_dataloader - self.enable_step_auto_dump = common_config.enable_step_auto_dump self.scope = task_config.scope if task_config.scope else [] self.list = task_config.list if task_config.list else [] self.data_mode = task_config.data_mode if task_config.data_mode else ["all"] @@ -36,7 +35,16 @@ class DebuggerConfig: "preheat_step": task_config.preheat_step if task_config.preheat_step else 15, "max_sample": task_config.max_sample if task_config.max_sample else 20, } - + + self.online_run_ut = False + if self.task == Const.TENSOR: + # dump api tensor and collaborate with online run_ut + self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False + self.nfs_path = task_config.nfs_path if task_config.nfs_path else "" + self.tls_path = task_config.tls_path if task_config.tls_path else "" + self.host = task_config.host if task_config.host else "" + self.port = task_config.port if task_config.port else -1 + self.check() if self.step: self.step.sort() diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 2d5668009..8433f0af6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -5,10 +5,13 @@ from msprobe.pytorch.service import Service from msprobe.pytorch.common.log import logger from msprobe.pytorch.pt_config import parse_json_config from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.const import Const +from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor class PrecisionDebugger: _instance = None + tasks_not_need_debugger = [Const.GRAD_PROBE] def __new__(cls, *args, **kwargs): if cls._instance is None: @@ -25,17 +28,18 @@ class PrecisionDebugger: level=None, model=None, step=None, - enable_step_auto_dump=None ): if not hasattr(self, "initialized"): self.api_origin = False self.initialized = True self.model = self.check_model_valid(model) common_config, task_config = parse_json_config(config_path, task) + self.task = common_config.task + if self.task == Const.GRAD_PROBE: + self.gm = GradientMonitor(common_config, task_config) + return if step: common_config.step = step - if enable_step_auto_dump: - common_config.enable_step_auto_dump = enable_step_auto_dump self.config = DebuggerConfig( common_config, task_config, task, dump_path, level ) @@ -45,9 +49,6 @@ class PrecisionDebugger: if self.enable_dataloader: logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) - self.enable_step_auto_dump = self.config.enable_step_auto_dump - if self.enable_step_auto_dump: - self.start_for_optimizer() @property def instance(self): @@ -61,30 +62,15 @@ class PrecisionDebugger: MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。" ) - # 非侵入式dump使能方法 - @classmethod - def start_for_optimizer(cls): - instance = cls._instance - if not instance: - raise Exception("No instance of PrecisionDebugger found.") - elif torch.__version__ < '2.0.0': - raise Exception("Pytorch version is earlier than 2.0.0 does not support optimizer hooks. \ - Please turn off enable_step_auto_dump and use the start, stop and step methods of PrecisionDebugger instead.") - else: - logger.info_on_rank_0("The enable_step_auto_dump is on and start()/stop()/step() will not take effect.") - logger.warning_on_rank_0("Customized optimizer iteration is not supported. Please use start, stop and step methods when using customized optimizer.") - instance.service.hook_optimizer(instance.model) - instance.service.start(instance.model) - @classmethod def start(cls): instance = cls._instance + if instance.task in PrecisionDebugger.tasks_not_need_debugger: + return if not instance: raise Exception("No instance of PrecisionDebugger found.") if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, start() skipped.") - elif instance.enable_step_auto_dump: - logger.warning_on_rank_0("optimizer is enabled, start() skipped.") else: instance.service.start(instance.model, instance.api_origin) instance.api_origin = False @@ -93,37 +79,37 @@ class PrecisionDebugger: @classmethod def forward_backward_dump_end(cls): instance = cls._instance - if not instance: - raise Exception("PrecisionDebugger instance is not created.") - if instance.enable_dataloader: - logger.warning_on_rank_0("DataLoader is enabled, forward_backward_dump_end() skipped.") - elif instance.enable_step_auto_dump: - logger.warning_on_rank_0("optimizer is enabled, forward_backward_dump_end() skipped.") - else: - instance.service.forward_backward_dump_end() - instance.api_origin = True + instance.service.forward_backward_dump_end() + instance.api_origin = True @classmethod def stop(cls): instance = cls._instance + if instance.task in PrecisionDebugger.tasks_not_need_debugger: + return if not instance: raise Exception("PrecisionDebugger instance is not created.") if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.") - elif instance.enable_step_auto_dump: - logger.warning_on_rank_0("optimizer is enabled, stop() skipped.") else: instance.service.stop() @classmethod def step(cls): - instance = cls._instance - if not instance: + if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger: + return + if not cls._instance: raise Exception("PrecisionDebugger instance is not created.") - elif instance.enable_step_auto_dump: - logger.warning_on_rank_0("optimizer is enabled, step() skipped.") - else: - instance.service.step() + cls._instance.service.step() + + @classmethod + def monitor(cls, model): + if not cls._instance: + raise Exception("PrecisionDebugger instance is not created.") + if cls._instance.task != Const.GRAD_PROBE: + return + cls._instance.gm.monitor(model) + def iter_tracer(func): def func_wrapper(*args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py index d234898c0..f0b40ccb2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py @@ -1,4 +1,4 @@ -from msprobe.core.common.log import logger +from msprobe.pytorch.common.log import logger from msprobe.core.common.exceptions import FreeBenchmarkException from msprobe.core.common.const import Const diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py index 6781a1c2f..e58223e59 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py @@ -2,7 +2,7 @@ import torch from msprobe.core.common.exceptions import FreeBenchmarkException from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import CommonField -from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, @@ -16,7 +16,6 @@ class GradSaver: self.handler_params = handler_params self.api_name = handler_params.api_name self.origin_func = origin_func - self.data_params = DataParams() self.is_compare = True self.kwargs = dict() self.perturbed_grad_input = tuple() @@ -61,28 +60,25 @@ class GradSaver: _index += 1 def compare_grad_results(self, handler, origin_grad, perturbed_grad, index): - # TODO get dtype? - self.data_params.original_result = origin_grad - self.data_params.perturbed_result = perturbed_grad - self.data_params.grad_unequal_flag = False - self.data_params.valid_input_index = index + data_params = DataParams() + data_params.original_result = origin_grad + data_params.perturbed_result = perturbed_grad + data_params.grad_unequal_flag = False + data_params.valid_input_index = index try: - handler.handle(self.data_params) - if not self.data_params.is_consistent: + handler.handle(data_params) + if not data_params.is_consistent: self.is_compare = False - self.data_params.grad_unequal_flag = True - self.data_params.is_consistent = True - self.data_params.perturbed_result = self.perturbed_grad_input - self.data_params.original_result = self.origin_grad_input - handler.handle(self.data_params) + data_params.grad_unequal_flag = True + data_params.is_consistent = True + data_params.perturbed_result = self.perturbed_grad_input + data_params.original_result = self.origin_grad_input + handler.handle(data_params) except Exception as e: logger.warning_on_rank_0( f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." f"{e}" ) - # 在扰动前后输出对比后释放输出的引用 - self.data_params.perturbed_result = None - self.data_params.original_result = None def check_grad_input(self, origin_grad, new_grad_index): if self.perturbed_grad_input is None: @@ -164,20 +160,20 @@ class GradSaver: return grad_input def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args): - self.data_params.args = [need_grad_tensors, grad_output, inner_args] - self.data_params.kwargs = {} - self.data_params.valid_input_index = 0 - self.data_params.origin_func = self.get_grad_input_from_vjp + data_params = data_pre_deal( + self.handler_params.api_name, + self.get_grad_input_from_vjp, + [need_grad_tensors, grad_output, inner_args], + {} + ) layer = LayerFactory.create( self.handler_params.api_name, self.handler_params.fuzz_device, self.handler_params.pert_mode, ) - layer.handle(self.data_params) - # 在计算扰动输出之后,释放输入的引用 - self.data_params.args = None + layer.handle(data_params) # 确定扰动成功后,才会暂存 - if self.data_params.perturbed_result: + if data_params.perturbed_result: self.perturbed_grad_input = tuple( - [x.cpu() for x in self.data_params.perturbed_result] + [x.cpu() for x in data_params.perturbed_result] ) diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 971776d13..69ece0a0c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -10,7 +10,10 @@ from msprobe.pytorch.free_benchmark.common.enums import ( HandlerType, PerturbationMode, ) -from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params +from msprobe.pytorch.free_benchmark.common.params import ( + data_pre_deal, + make_handler_params, +) from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( @@ -70,9 +73,9 @@ class FreeBenchmarkCheck(ABC): layer.handle(data_params) handler_params = make_handler_params(name, self.config, self.current_iter) handler = FuzzHandlerFactory.create(handler_params) - handler.handle(data_params) - return data_params.perturbed_result, handler.get_unequal_rows() - + perturbed_output = handler.handle(data_params) + return perturbed_output, handler.get_unequal_rows() + def backward(self, name, module, grad_output): if not self.config.fuzz_stage == Const.BACKWARD: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index a18ef1c51..2ccc2bfcf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -32,7 +32,7 @@ class AddNoiseLayer(NpuBaseLayer): return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 45dea7b93..a0ac21691 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -48,7 +48,7 @@ class BitNoiseLayer(NpuBaseLayer): return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index 91085d57a..ae5bf9f03 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -39,7 +39,7 @@ class ChangeValueLayer(NpuBaseLayer): return type(tensor_obj)([self.change_value(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index ad6d8b898..b5a106dac 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -17,7 +17,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): and torch.is_floating_point(tensor_obj) and tensor_obj.dtype not in [torch.float32, torch.float64] ): - self._set_improve_valus(tensor_obj) + self._set_improve_values(tensor_obj) tensor_obj = self._change_dtype(tensor_obj) self.is_added = True return tensor_obj @@ -32,7 +32,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): ) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): logger.info_on_rank_0( f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}." @@ -50,7 +50,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): params.perturbed_result = params.origin_func(*new_args, **new_kwargs) return params.perturbed_result - def _set_improve_valus(self, inputs): + def _set_improve_values(self, inputs): if inputs.dtype in [torch.float16, torch.bfloat16]: self.perturbed_value = torch.float32 diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index a69c56002..fa775e00e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -16,7 +16,7 @@ class NoChangeLayer(NpuBaseLayer): self.is_added = True return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py index d34ac9765..376f4ee3e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py @@ -8,7 +8,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer class CpuLayer(BaseLayer): - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): logger.info_on_rank_0( f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}." diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py index 5ee968c6a..46efd8283 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py @@ -22,7 +22,6 @@ class FuzzHandlerFactory: handler = FuzzHandlerFactory.result_handlers.get(params.handler_type) else: handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT) - # TODO if not handler: raise FreeBenchmarkException( FreeBenchmarkException.UnsupportedType, diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py index c2fd8bfd0..5bdb20584 100644 --- a/debug/accuracy_tools/msprobe/pytorch/function_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -2,11 +2,12 @@ from msprobe.pytorch.common.utils import logger from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \ npu_confusion_transpose_backward -from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward +from msprobe.pytorch.bench_functions.fast_gelu import npu_fast_gelu, npu_fast_gelu_backward from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward -from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad +from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad, \ + gpu_fusion_attention from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \ @@ -62,8 +63,8 @@ class Register(dict): # register for npu custom bench functions npu_custom_functions = Register() npu_custom_functions([ - npu_apply_adam_w, npu_confusion_transpose, fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention, - npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu + npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention, + npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention ]) # register for npu custom backward bench functions diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py index efb95c336..5d2e8d985 100644 --- a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py @@ -24,7 +24,7 @@ def module_dump(module, dump_name): dump_name = dump_name + Const.SEP + str(module_count.get(dump_name)) + Const.SEP pdg = PrecisionDebugger() - _, forward_hook, backward_hook = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name) + _, forward_hook, backward_hook, _ = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name) module.register_forward_hook(forward_hook, with_kwargs=True) module.register_full_backward_hook(backward_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/__init__.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py new file mode 100644 index 000000000..b577b2fab --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py @@ -0,0 +1,92 @@ +import os +from collections import defaultdict + +import torch +if int(torch.__version__.split('.')[0]) >= 2: + from torch.optim.optimizer import register_optimizer_step_pre_hook +from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv +from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target +from msprobe.core.grad_probe.constant import GradConst, level_adp +from msprobe.core.common.file_check import create_directory +from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import remove_path, write_csv, save_npy +from msprobe.pytorch.common.utils import get_rank_id, print_rank_0, save_pt + + +class GradientMonitor: + + def __init__(self, common_config, task_config): + level = task_config.grad_level + if level not in level_adp: + raise Exception(f"level is valid, not in {level_adp.keys()}") + self._level_adp = level_adp[level] + self._param_list = task_config.param_list + self._target_ranks = common_config.rank + logger.info(f"target rank {self._target_ranks}") + self._target_step = common_config.step + logger.info(f"target step {self._target_step}") + self._bounds = task_config.bounds + check_numeral_list_ascend(self._bounds) + self._output_path = common_config.dump_path + if not os.path.exists(self._output_path): + create_directory(self._output_path) + else: + logger.warning(f"the file in {self._output_path} will be recoverd") + self._step = -1 + self._param2name = defaultdict(str) + + @property + def output_path(self): + return self._output_path + + @staticmethod + def save_grad_direction(param_name, grad, save_path): + if not os.path.exists(save_path): + create_directory(save_path) + param_grad = grad.clone().detach() + is_positive = param_grad > 0 + save_filepath = os.path.join(save_path, f"{param_name}.npy") + save_npy(is_positive.numpy(), save_filepath) + + def monitor(self, model): + print_rank_0("> parameter names:") + for name, param in model.named_parameters(): + self._param2name[param] = name + print_rank_0(f"\t{name}") + setattr(self, "_rank", get_rank_id()) + if torch.distributed.is_initialized() and not data_in_list_target(getattr(self, "_rank"), self._target_ranks): + return + self._hook_optimizer() + + def _hook_optimizer(self): + def optimizer_pre_step_hook(optimizer, args, kargs): + self._step += 1 + logger.info(f"grad_probe: optimizer step {self._step}") + if not data_in_list_target(self._step, self._target_step): + return + output_lines = [] + for param, param_name in self._param2name.items(): + if not data_in_list_target(param_name, self._param_list): + continue + grad = param.main_grad if hasattr(param, "main_grad") else param.grad + if grad is None: + logger.info(f"grad is None: {param_name}") + continue + grad_info = GradStatCsv.generate_csv_line(param_name, self._level_adp, grad, self._bounds) + output_lines.append(grad_info) + if self._level_adp["have_grad_direction"]: + GradientMonitor.save_grad_direction(param_name, grad, + f'{self._output_path}/rank{self._rank}/step{self._step}') + output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}") + if not os.path.isdir(output_dirpath): + create_directory(output_dirpath) + output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv") + if os.path.exists(output_path): + logger.warning(f"{output_path} will be recoverd") + remove_path(output_path) + header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds) + output_lines.insert(0, header_result) + write_csv(output_lines, output_path) + logger.info(f"write grad data to {output_path}") + if int(torch.__version__.split('.')[0]) >= 2: + register_optimizer_step_pre_hook(optimizer_pre_step_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py new file mode 100644 index 000000000..757a1aebf --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py @@ -0,0 +1,129 @@ +from abc import ABC, abstractmethod +from collections import namedtuple +import hashlib +import torch +from msprobe.core.grad_probe.constant import GradConst + +CSV_header_input = namedtuple("CSV_header_input", ["bounds"]) +CSV_content_input = namedtuple("CSV_content_input", ["grad", "bounds"]) + + +class GradStatCsv: + csv = {} + + @staticmethod + def generate_csv_header(level, bounds): + header = ["param_name"] + for key in level["header"]: + csv_header_input = CSV_header_input(bounds=bounds) + header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input)) + return header + + @staticmethod + def generate_csv_line(param_name, level, grad, bounds): + line = [param_name] + for key in level["header"]: + csv_content_input = CSV_content_input(grad=grad, bounds=bounds) + line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input)) + return line + + +def register_csv_item(key, cls=None): + if cls is None: + # 无参数时,返回装饰器函数 + return lambda cls: register_csv_item(key, cls) + GradStatCsv.csv[key] = cls + return cls + + +class CsvItem(ABC): + @abstractmethod + def generate_csv_header(csv_header_input): + pass + + @abstractmethod + def generate_csv_content(csv_content_input): + pass + + +@register_csv_item(GradConst.MD5) +class CSV_md5(CsvItem): + def generate_csv_header(csv_header_input): + return ["MD5"] + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + tensor_bytes = grad.cpu().detach().float().numpy().tobytes() + md5_hash = hashlib.md5(tensor_bytes) + return [md5_hash.hexdigest()] + + +@register_csv_item(GradConst.DISTRIBUTION) +class CSV_distribution(CsvItem): + def generate_csv_header(csv_header_input): + bounds = csv_header_input.bounds + intervals = [] + if bounds: + intervals.append(f"(-inf, {bounds[0]}]") + for i in range(1, len(bounds)): + intervals.append(f"({bounds[i-1]}, {bounds[i]}]") + if intervals: + intervals.append(f"({bounds[-1]}, inf)") + intervals.append("=0") + + return intervals + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + bounds = csv_content_input.bounds + grad = grad.cpu().detach() + if grad.dtype == torch.bfloat16: + grad = grad.to(torch.float32) + element_num = grad.numel() + grad_equal_0_num = (grad == 0).sum().item() + bound = torch.Tensor(bounds) + bucketsize_result = torch.bucketize(grad, bound) + interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bound) + 1)] + interval_nums.append(grad_equal_0_num) + return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums] + return return_list + + +@register_csv_item(GradConst.MAX) +class CSV_max(CsvItem): + def generate_csv_header(csv_header_input): + return ["max"] + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + return [torch.max(grad).cpu().detach().float().numpy().tolist()] + + +@register_csv_item(GradConst.MIN) +class CSV_max(CsvItem): + def generate_csv_header(csv_header_input): + return ["min"] + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + return [torch.min(grad).cpu().detach().float().numpy().tolist()] + + +@register_csv_item(GradConst.NORM) +class CSV_max(CsvItem): + def generate_csv_header(csv_header_input): + return ["norm"] + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + return [torch.norm(grad).cpu().detach().float().numpy().tolist()] + + +@register_csv_item(GradConst.SHAPE) +class CSV_shape(CsvItem): + def generate_csv_header(csv_header_input): + return ["shape"] + + def generate_csv_content(csv_content_input): + grad = csv_content_input.grad + return [list(grad.shape)] \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index ff6427e51..aa724b50f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.utils.hooks as full_hooks from msprobe.core.common.const import Const +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' class HOOKModule(nn.Module): @@ -48,9 +49,13 @@ class HOOKModule(nn.Module): else: HOOKModule.module_count[self.prefix] += 1 self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP - forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix) - self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) - self.register_forward_hook(forward_hook, with_kwargs=True) + forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix) + if torch_version_above_or_equal_2: + self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + self.register_forward_hook(forward_hook, with_kwargs=True) + else: + self.register_forward_pre_hook(forward_pre_hook) + self.register_forward_hook(forward_hook) self.register_backward_hook(backward_hook) def __call__(self, *input, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index f68708e94..11281396b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -1874,4 +1874,6 @@ distributed: - _reduce_scatter_base - _all_gather_base - all_to_all_single - - all_to_all \ No newline at end of file + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py index c1e581675..d2370b6d0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py @@ -16,14 +16,15 @@ """ import os -import yaml +from msprobe.core.common.utils import load_yaml -from msprobe.core.common.file_check import FileOpen -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - Ops = yaml.safe_load(f) - WrapFunctionalOps = Ops.get('functional') - WrapTensorOps = Ops.get('tensor') - WrapTorchOps = Ops.get('torch') +def get_ops(): + cur_path = os.path.dirname(os.path.realpath(__file__)) + yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") + ops = load_yaml(yaml_path) + wrap_functional = ops.get('functional') + wrap_tensor = ops.get('tensor') + wrap_torch = ops.get('torch') + wrap_npu_ops = ops.get('torch_npu') + return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py index a02abbe5f..b63bd9d02 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py @@ -18,20 +18,17 @@ import os import torch -import yaml - from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml from msprobe.pytorch.function_factory import npu_custom_grad_functions cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - Ops = yaml.safe_load(f) - WrapAtenOps = Ops.get('aten') - WhiteAtenOps = Ops.get('white_aten_ops', []) +ops = load_yaml(yaml_path) +wrap_aten_ops = ops.get('aten') +white_aten_ops = ops.get('white_aten_ops', []) aten_func = {} @@ -40,9 +37,9 @@ for f in dir(torch.ops.aten): def get_aten_ops(): - global WrapAtenOps + global wrap_aten_ops _all_aten_ops = dir(torch.ops.aten) - return set(WrapAtenOps) & set(_all_aten_ops) + return set(wrap_aten_ops) & set(_all_aten_ops) class HOOKAtenOP(object): @@ -69,7 +66,7 @@ class AtenOPTemplate(HOOKModule): if isinstance(self.op, str): if self.op in npu_custom_grad_functions: return npu_custom_grad_functions[self.op](*args, **kwargs) - if self.op in WhiteAtenOps: + if self.op in white_aten_ops: return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs) if self.op not in aten_func: raise Exception(f"Skip op[{self.op}] accuracy check, because the op is not " diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py index 6cf425441..1f720a32b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py @@ -18,18 +18,15 @@ import os from functools import wraps import torch.distributed as dist -import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapDistributedOps = yaml.safe_load(f).get('distributed') distributed_func = {} @@ -38,9 +35,10 @@ for f in dir(dist): def get_distributed_ops(): - global WrapDistributedOps _all_distributed_ops = dir(dist) - return set(WrapDistributedOps) & set(_all_distributed_ops) + yaml_data = load_yaml(yaml_path) + wrap_distributed_ops = yaml_data.get('distributed') + return set(wrap_distributed_ops) & set(_all_distributed_ops) class HOOKDistributedOP(object): @@ -57,7 +55,12 @@ class DistributedOPTemplate(HOOKModule): @torch_device_guard def forward(self, *args, **kwargs): - return distributed_func.get(self.op_name_)(*args, **kwargs) + if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: + handle = distributed_func.get(self.op_name_)(*args, **kwargs) + handle.wait() + return handle + else: + return distributed_func.get(self.op_name_)(*args, **kwargs) def wrap_distributed_op(op_name, hook): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py index fd7610ca8..95715ec1a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py @@ -16,15 +16,13 @@ """ import os - import torch -import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml def remove_dropout(): @@ -66,14 +64,13 @@ def remove_dropout(): cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapFunctionalOps = yaml.safe_load(f).get('functional') def get_functional_ops(): - global WrapFunctionalOps + yaml_data = load_yaml(yaml_path) + wrap_functional_ops = yaml_data.get('functional') _all_functional_ops = dir(torch.nn.functional) - return set(WrapFunctionalOps) & set(_all_functional_ops) + return set(wrap_functional_ops) & set(_all_functional_ops) TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()} diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py index 8a67ed942..26add2b4a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py @@ -17,18 +17,16 @@ import os import torch -import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml from msprobe.pytorch.function_factory import npu_custom_functions cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapNpuOps = yaml.safe_load(f).get('torch_npu') + try: import torch_npu @@ -38,13 +36,19 @@ else: is_gpu = False +cuda_func_mapping = { + "npu_fusion_attention" : "gpu_fusion_attention" +} + + def get_npu_ops(): - global WrapNpuOps if torch_without_guard_version: _npu_ops = dir(torch.ops.npu) else: _npu_ops = dir(torch_npu._C._VariableFunctionsClass) - return set(WrapNpuOps) & set(_npu_ops) + yaml_data = load_yaml(yaml_path) + wrap_npu_ops = yaml_data.get('torch_npu') + return set(wrap_npu_ops) & set(_npu_ops) class HOOKNpuOP(object): @@ -53,10 +57,11 @@ class HOOKNpuOP(object): class NpuOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): + def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE): self.op_name_ = op_name self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP self.need_hook = need_hook + self.device = device if need_hook: super().__init__(hook) @@ -65,7 +70,10 @@ class NpuOPTemplate(HOOKModule): if not self.need_hook: if self.op_name_ not in npu_custom_functions: raise Exception(f'There is not bench function {self.op_name_}') - return npu_custom_functions[self.op_name_](*args, **kwargs) + if self.device == Const.CUDA_LOWERCASE: + self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_) + if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: + return npu_custom_functions[self.op_name_](*args, **kwargs) if torch_without_guard_version: return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py index 3e26ae3be..90bb3c61b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py @@ -18,23 +18,22 @@ import os import torch -import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml + cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapTensorOps = yaml.safe_load(f).get('tensor') def get_tensor_ops(): - global WrapTensorOps _tensor_ops = dir(torch.Tensor) - return set(WrapTensorOps) & set(_tensor_ops) + yaml_data = load_yaml(yaml_path) + wrap_tensor_ops = yaml_data.get('tensor') + return set(wrap_tensor_ops) & set(_tensor_ops) TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()} diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py index 486ddda49..32d086a6d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py @@ -16,25 +16,23 @@ """ import os - import torch -import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import load_yaml + cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapTorchOps = yaml.safe_load(f).get('torch') def get_torch_ops(): - global WrapTorchOps _torch_ops = [] - for operation in WrapTorchOps: + yaml_data = load_yaml(yaml_path) + wrap_torch_ops = yaml_data.get('torch') + for operation in wrap_torch_ops: if '.' in operation: operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) operation_sub_module = getattr(torch, operation_sub_module_name) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py index d78beb2a6..022535824 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py @@ -16,24 +16,22 @@ """ import os - import torch -import yaml +from msprobe.core.common.const import Const +from msprobe.core.common.utils import load_yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.core.common.file_check import FileOpen from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const + cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapVfOps = yaml.safe_load(f).get('_VF') def get_vf_ops(): - global WrapVfOps - return WrapVfOps + yaml_data = load_yaml(yaml_path) + wrap_vf_ops = yaml_data.get('_VF') + return wrap_vf_ops class HOOKVfOP(object): diff --git a/debug/accuracy_tools/msprobe/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py index 3e9969d32..e6d2125e4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py @@ -5,6 +5,7 @@ from torch.utils.hooks import BackwardHook from msprobe.core.common.const import Const from msprobe.core.data_dump.scope import ModuleRangeScope +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' class ModuleProcesser: @@ -109,7 +110,29 @@ class ModuleProcesser: if self.scope: self.scope.end_module(module.mindstudio_reserved_name) - if Const.START in start_or_stop: - return pre_hook + def backward_hook(module, input, output=None): + try: + index = ModuleProcesser.module_count_func(name_prefix) + except IndexError as e: + index = None + pass + module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index) + forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD) + ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace( + Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None + ModuleProcesser.api_parent_node = None + if self.scope: + self.scope.begin_module(full_name) + + if torch_version_above_or_equal_2: + if Const.START in start_or_stop: + return pre_hook + else: + return end_hook else: - return end_hook + if Const.FORWARD in name_prefix and Const.START in start_or_stop: + return pre_hook + elif Const.BACKWARD in name_prefix: + return backward_hook + else: + return end_hook diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index 048ab3f90..5755e1ecc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -6,10 +6,9 @@ import json from collections import namedtuple from rich.table import Table from rich.console import Console -from .single_compare import single_benchmark_compare_wrap -from .utils import DispatchException -from msprobe.core.common.const import CompareConst -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.const import CompareConst, FileCheckConst +from msprobe.core.common.file_check import FileOpen, change_mode +from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import CompareException @@ -42,6 +41,7 @@ def write_csv(data, filepath): with FileOpen(filepath, 'a', encoding='utf-8-sig') as f: writer = csv.writer(f) writer.writerows(data) + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) class Saver: diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index 898df30b9..541218852 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -1,10 +1,8 @@ import os import time import json -from pathlib import Path -from multiprocessing import Manager, Pool +from multiprocessing import Pool -import yaml import torch from torch.utils._python_dispatch import TorchDispatchMode @@ -16,14 +14,15 @@ except ImportError: else: is_npu = True -from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ - DispatchRunParam, DisPatchDataInfo -from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ - DispatchException -from .compare import Comparator -from msprobe.core.common.file_check import FileOpen -from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create +from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, load_yaml from msprobe.core.common.const import Const, CompareConst +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ + DispatchRunParam, DisPatchDataInfo +from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, COMPARE_LOGO +from msprobe.pytorch.online_dispatch.compare import Comparator +from msprobe.core.common.file_check import FileOpen, create_directory + current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" @@ -33,12 +32,12 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" class PtdbgDispatch(TorchDispatchMode): def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0): super(PtdbgDispatch, self).__init__() - logger_logo() + logger.info(COMPARE_LOGO) if not is_npu: - logger_error("Please confirm you run environment installed torch_npu!") + logger.error("Please confirm you run environment installed torch_npu!") return if dump_path is None: - logger_error("Please set dump_path when dump_mode is config!") + logger.error("Please set dump_path when dump_mode is config!") check_file_or_directory_path(dump_path, True) self.device_id = torch_npu._C._npu_getDevice() @@ -49,7 +48,7 @@ class PtdbgDispatch(TorchDispatchMode): self.single_api_index_dict = {} self.device_dump_path_cpu = None self.device_dump_path_npu = None - self.all_summery = [] + self.all_summary = [] self.call_stack_list = [] self.process_num = process_num self.filter_dump_api() @@ -60,8 +59,8 @@ class PtdbgDispatch(TorchDispatchMode): self.root_npu_path = os.path.join(self.root_path, f'npu') check_path_before_create(self.root_cpu_path) check_path_before_create(self.root_npu_path) - Path(self.root_cpu_path).mkdir(mode=0o750, parents=True, exist_ok=True) - Path(self.root_npu_path).mkdir(mode=0o750, parents=True, exist_ok=True) + create_directory(self.root_cpu_path) + create_directory(self.root_npu_path) self.result_csv_path = os.path.join(self.root_path, RESULT_FILE_NAME) self.detail_csv_path = os.path.join(self.root_path, DETAILS_FILE_NAME) @@ -70,13 +69,13 @@ class PtdbgDispatch(TorchDispatchMode): self.aten_ops_blacklist = [] self.npu_adjust_autogard = [] yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml") - self.load_yaml_file(yaml_path) + self.get_ops(yaml_path) self.lock = None if process_num > 0: self.pool = Pool(process_num) if debug: - logger_debug(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} ' + logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} ' f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], ' f'process[{process_num}]') @@ -85,17 +84,17 @@ class PtdbgDispatch(TorchDispatchMode): if not is_npu: return - logger_debug(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}') + logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}') if self.process_num > 0: self.pool.close() self.pool.join() - summery_path = os.path.join(self.root_cpu_path, f'summary.json') - if not os.path.exists(summery_path): - logger_error("Please check train log, An exception may have occurred!") + summary_path = os.path.join(self.root_cpu_path, f'summary.json') + if not os.path.exists(summary_path): + logger.error("Please check train log, An exception may have occurred!") return - check_file_or_directory_path(summery_path, False) - fp_handle = open(summery_path, "r") + check_file_or_directory_path(summary_path, False) + fp_handle = FileOpen(summary_path, "r") while True: json_line_data = fp_handle.readline() if json_line_data == '\n': @@ -103,7 +102,7 @@ class PtdbgDispatch(TorchDispatchMode): if len(json_line_data) == 0: break msg = json.loads(json_line_data) - self.all_summery[msg[0]] = msg[1] + self.all_summary[msg[0]] = msg[1] fp_handle.close() if self.debug_flag: @@ -111,20 +110,20 @@ class PtdbgDispatch(TorchDispatchMode): output_num = 0 total_num = 0 - for list_data in self.all_summery: + for list_data in self.all_summary: for data in list_data: - logger_debug(f'summery: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]') + logger.info(f'summary: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]') if "_input" in data[CompareConst.NPU_NAME]: input_num = input_num + 1 if "_output" in data[CompareConst.NPU_NAME]: output_num = output_num + 1 total_num = total_num + 1 - logger_debug(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] ' + logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] ' f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]') def __torch_dispatch__(self, func, types, args=(), kwargs=None): if not is_npu: - logger_error("Please confirm you run environment installed torch_npu!") + logger.error("Please confirm you run environment installed torch_npu!") return func(*args, **kwargs) func_name_split_list = func.__name__.split(".") @@ -132,7 +131,7 @@ class PtdbgDispatch(TorchDispatchMode): try: aten_api_overload_name = func_name_split_list[1] except IndexError: - logger_error(f"Please check the func name {func.__name__}!") + logger.error(f"Please check the func name {func.__name__}!") return func(*args, **kwargs) self.enable_autogard(aten_api) @@ -151,7 +150,7 @@ class PtdbgDispatch(TorchDispatchMode): run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name) if self.debug_flag: - logger_debug(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], ' + logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], ' f'Name[{run_param.aten_api}_{run_param.single_api_index}], ' f'Count[{self.api_index}], Sys[{get_sys_info()}]') @@ -175,21 +174,21 @@ class PtdbgDispatch(TorchDispatchMode): cpu_out = cpu_out.float() if self.process_num == 0: - self.all_summery.append([]) - data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, func, npu_out_cpu, cpu_out, self.lock) + self.all_summary.append([]) + data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out, self.lock) dispatch_workflow(run_param, data_info) else: self.lock.acquire() - self.all_summery.append([]) + self.all_summary.append([]) self.lock.release() run_param.process_flag = True if self.check_fun(func, run_param): - data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, None, npu_out_cpu, cpu_out, + data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out, self.lock) self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info), error_callback=error_call) else: - logger_error("can not get correct function please set process_num=0") + logger.error("can not get correct function please set process_num=0") return npu_out @staticmethod @@ -208,17 +207,16 @@ class PtdbgDispatch(TorchDispatchMode): time.sleep(1) time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) if tag is None or not isinstance(tag, str): - logger_warn('There is not tag or the type of tag is not string.') + logger.warning('There is not tag or the type of tag is not string.') dir_name = f'msprobe_rank{self.device_id}_{time_now}' else: dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}' return dir_name - def load_yaml_file(self, file_path): - with FileOpen(file_path, 'r') as f: - yaml_file = yaml.safe_load(f) - self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist') - self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard') + def get_ops(self, file_path): + yaml_file = load_yaml(file_path) + self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist') + self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard') def filter_dump_api(self): if self.dump_mode != Const.LIST or not self.dump_api_list: @@ -230,7 +228,7 @@ class PtdbgDispatch(TorchDispatchMode): if aten_api in aten_api_list: dump_api_list.append(aten_api) else: - logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten') + logger.warning(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten') self.dump_api_list = dump_api_list def get_run_param(self, aten_api, func_name, aten_api_overload_name): @@ -257,16 +255,16 @@ class PtdbgDispatch(TorchDispatchMode): def check_param(self): if self.dump_mode not in Const.ONLINE_DUMP_MODE: - logger_error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE)) + logger.error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE)) raise DispatchException(DispatchException.INVALID_PARAMETER) if not isinstance(self.dump_api_list, list): - logger_error('The type of parameter "api_list" can only be list.') + logger.error('The type of parameter "api_list" can only be list.') raise DispatchException(DispatchException.INVALID_PARAMETER) if not isinstance(self.debug_flag, bool): - logger_error('The type of parameter "debug" can only be bool.') + logger.error('The type of parameter "debug" can only be bool.') raise DispatchException(DispatchException.INVALID_PARAMETER) if not isinstance(self.process_num, int) or self.process_num < 0: - logger_error('The type of parameter "process_num" can only be int and it should not be less than 0.') + logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.') raise DispatchException(DispatchException.INVALID_PARAMETER) def enable_autogard(self, aten_api): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index f83b6fc9f..82cb56e3d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -5,11 +5,10 @@ from datetime import datetime, timezone import pandas as pd import torch -from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ - COLOR_RESET, CSV_COLUMN_NAME -from msprobe.core.common.file_check import FileOpen, change_mode -from msprobe.core.common.const import CompareConst, FileCheckConst, Const from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import save_npy + class DispatchRunParam: def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator): @@ -32,10 +31,10 @@ class DispatchRunParam: class DisPatchDataInfo: - def __init__(self, cpu_args, cpu_kwargs, all_summery, func, npu_out_cpu, cpu_out, lock): + def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock): self.cpu_args = cpu_args self.cpu_kwargs = cpu_kwargs - self.all_summery = all_summery + self.all_summary = all_summary self.func = func self.npu_out_cpu = npu_out_cpu self.cpu_out = cpu_out @@ -57,7 +56,7 @@ class TimeStatistics: def __enter__(self): if self.debug: self.time = datetime.now(tz=timezone.utc) - logger_debug(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \ + logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \ f'Id[{self.index}]') def __exit__(self, exc_type, exc_val, exc_tb): @@ -68,9 +67,9 @@ class TimeStatistics: hot_time_cost = "Hotspot " + time_cost if cost_time.total_seconds() > self.timeout: - logger_debug(hot_time_cost) + logger.info(hot_time_cost) else: - logger_debug(time_cost) + logger.info(time_cost) def support_basic_type(data): @@ -87,24 +86,25 @@ def dump_data(data, prefix, dump_path): elif support_basic_type(data): if isinstance(data, torch.Tensor) and data.is_meta: return - # dump data may greater than summery_list collect - np_save_data(data, prefix, dump_path) + # dump data may greater than summary_list collect + path = os.path.join(dump_path, f'{prefix}.npy') + save_npy(data, path) -def save_temp_summery(api_index, single_api_summery, path, lock): - summery_path = os.path.join(path, f'summery.json') +def save_temp_summary(api_index, single_api_summary, path, lock): + summary_path = os.path.join(path, f'summary.json') lock.acquire() - with FileOpen(summery_path, "a") as f: - json.dump([api_index, single_api_summery], f) + with FileOpen(summary_path, "a") as f: + json.dump([api_index, single_api_summary], f) f.write('\n') lock.release() def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs - all_summery, func = data_info.all_summery, data_info.func + all_summary, func = data_info.all_summary, data_info.func npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock - single_api_summery = [] + single_api_summary = [] prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input' prefix_output = f'{run_param.aten_api}_{run_param.single_api_index}_output' @@ -127,9 +127,9 @@ def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path) if run_param.process_num == 0: - all_summery[run_param.api_index - 1] = copy.deepcopy(single_api_summery) + all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary) else: - save_temp_summery(run_param.api_index - 1, single_api_summery, run_param.root_cpu_path, lock) + save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock) def get_torch_func(run_param): @@ -155,32 +155,3 @@ def dispatch_multiprocess(run_param, dispatch_data_info): def error_call(err): logger.error(f'multiprocess {err}') - -def save_csv(all_summery, call_stack_list, csv_path): - df = pd.DataFrame(columns=CSV_COLUMN_NAME) - - for index, list_data in enumerate(all_summery): - for data in list_data: - csv_row_data = {CompareConst.NPU_NAME: data[CompareConst.NPU_NAME], - CompareConst.BENCH_NAME: data[CompareConst.BENCH_NAME], - CompareConst.NPU_DTYPE: data[CompareConst.NPU_DTYPE], - CompareConst.BENCH_DTYPE: data[CompareConst.BENCH_DTYPE], - CompareConst.NPU_SHAPE: data[CompareConst.NPU_SHAPE], - CompareConst.BENCH_SHAPE: data[CompareConst.BENCH_SHAPE], - CompareConst.NPU_MAX: data[CompareConst.NPU_MAX], - CompareConst.NPU_MIN: data[CompareConst.NPU_MIN], - CompareConst.NPU_MEAN: data[CompareConst.NPU_MEAN], - CompareConst.BENCH_MAX: data[CompareConst.BENCH_MAX], - CompareConst.BENCH_MIN: data[CompareConst.BENCH_MIN], - CompareConst.BENCH_MEAN: data[CompareConst.BENCH_MEAN], - CompareConst.COSINE: data[CompareConst.COSINE], - CompareConst.MAX_ABS_ERR: data[CompareConst.MAX_ABS_ERR], - CompareConst.MAX_RELATIVE_ERR: data[CompareConst.MAX_RELATIVE_ERR], - CompareConst.ACCURACY: data[CompareConst.ACCURACY], - CompareConst.STACK: call_stack_list[index], - CompareConst.ERROR_MESSAGE: data[CompareConst.ERROR_MESSAGE]} - row_df = pd.DataFrame.from_dict(csv_row_data, orient='index').T - df = pd.concat([df, row_df]) - - df.to_csv(csv_path, index=False) - change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY) diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py index aa0afa4e4..83a6bafc6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py @@ -3,15 +3,15 @@ from functools import wraps import torch from prettytable import PrettyTable from collections import namedtuple -from .utils import logger_user, logger_debug +from msprobe.pytorch.common.log import logger def func_log_wrapper(): def _out_wrapper(func): @wraps(func) def _in_wrapper(*kargs, **kwargs): - logger_debug("start to run: {}".format(func.__name__)) + logger.info(f"start to run: {func.__name__}") x = func(*kargs, **kwargs) - logger_debug("end to run: {}".format(func.__name__)) + logger.info(f"end to run: {func.__name__}") return x return _in_wrapper @@ -157,15 +157,15 @@ class SingleBenchmarkAccuracyCompare: max_rel_idx=max_rel_idx ) acc_result.get_result(eb_thd, error_thd) - return CompareResultInfo(acc_result, error_thd, eb_thd, None) - return None + return CompareResultInfo(acc_result, error_thd, eb_thd, None) + @classmethod @func_log_wrapper() def compute_binary_diff(cls, npu_out, bench_out): result = torch.equal(npu_out, bench_out) if result: - logger_user("二进制精度比对通过, 无需单标杆比对法验证") + logger.info("二进制精度比对通过, 无需单标杆比对法验证") return SingleBenchmarkAccuracyResult(result=result, max_abs_diff=0, max_rel_diff=0, error_balance=0) @classmethod @@ -191,7 +191,7 @@ class SingleBenchmarkAccuracyCompare: zeros = torch.zeros_like(npu_out) diff_value = torch.subtract(npu_out, bench_out) diff_abs = torch.abs(diff_value) - abs_mask_idx = torch.where(torch.abs(bench_out) < benchmark_standard.small_value, ones, zeros) + abs_mask_idx = torch.where(torch.abs(bench_out) >= benchmark_standard.small_value, ones, zeros) abs_err_idx = torch.where(diff_abs > error_thd, ones, zeros) abs_err_idx = abs_err_idx * abs_mask_idx abs_err = diff_abs[torch.where(abs_err_idx == 1)] @@ -301,7 +301,7 @@ class SingleBenchSummary: table.add_row(["max_rel_diff", self.max_rel_diff, self.error_thd]) table.add_row(["max_rel_idx", self.max_rel_idx, "-"]) - logger_user(table) + logger.info(table) def to_column_value(self): return [self.bench_dtype, self.npu_dtype, self.shape, self.error_balance, @@ -354,7 +354,7 @@ def calc_status_details_dict(npu_out, bench_out, high_precision, summary): def calc_status_details_tensor(npu_out, bench_out, high_precision, summary): - return single_benchmark_compare(bench_out, npu_out) + return single_benchmark_compare(npu_out, bench_out) def calc_status_details_builtin(npu_out, bench_out, summary): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index fec3e0b00..52b6a637b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -1,6 +1,5 @@ import os import inspect -import logging import psutil import torch import numpy as np @@ -12,8 +11,7 @@ except ImportError: else: pta_cpu_device = torch.device("cpu") -from msprobe.core.common.const import CompareConst, FileCheckConst -from msprobe.core.common.file_check import change_mode +from msprobe.core.common.const import CompareConst cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' @@ -69,19 +67,6 @@ def get_callstack(): return callstack -def np_save_data(data, file_name, data_path): - try: - if hasattr(data, "numpy"): - data = data.numpy() - dump_path = os.path.join(data_path, f'{file_name}.npy') - np.save(dump_path, data) - change_mode(dump_path, FileCheckConst.DATA_FILE_AUTHORITY) - except Exception as e: - logger_error("save numpy failed, error: {}".format(e)) - finally: - pass - - def data_to_cpu(data, deep, data_cpu): global cpu_device list_cpu = [] @@ -124,47 +109,6 @@ def data_to_cpu(data, deep, data_cpu): return data -def get_mp_logger(): - logger = logging.getLogger(__name__) - if not logger.handlers: - logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = logging.Formatter('%(asctime)s %(message)s') - logger.propagate = True - handler.setFormatter(formatter) - logger.addHandler(handler) - return logger.info - - -def logger_debug(mesg): - logger = get_mp_logger() - logger(f'DEBUG ' + mesg) - - -def logger_info(mesg): - logger = get_mp_logger() - logger(f'INFO ' + mesg) - - -def logger_warn(mesg): - logger = get_mp_logger() - logger(f'{COLOR_YELLOW}WARNING {mesg} {COLOR_RESET}') - - -def logger_error(mesg): - logger = get_mp_logger() - logger(f'{COLOR_RED}ERROR {mesg} {COLOR_RESET}') - - -def logger_user(mesg): - logger = get_mp_logger() - logger(mesg) - - -def logger_logo(): - logger_user(f'{COLOR_CYAN}{COMPARE_LOGO} {COLOR_RESET}') - - def get_sys_info(): mem = psutil.virtual_memory() cpu_percent = psutil.cpu_percent(interval=1) diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py index 2b091c59e..add1dea38 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py @@ -22,6 +22,8 @@ from collections import namedtuple from msprobe.pytorch.parse_tool.lib.utils import Util from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException +from msprobe.core.common.utils import write_csv, save_npy_to_txt, load_npy +from msprobe.core.common.file_check import FileChecker, create_directory class Compare: @@ -36,7 +38,7 @@ class Compare: self.log.info("Compare finished!!") def compare_vector(self, my_dump_path, golden_dump_path, result_dir, msaccucmp_path): - self.util.create_dir(result_dir) + create_directory(result_dir) self.util.check_path_valid(result_dir) call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path) cmd = '%s %s compare -m %s -g %s -out %s' % ( @@ -65,7 +67,7 @@ class Compare: self.util.print_panel("\n".join(summary_txt)) def convert(self, dump_file, data_format, output, msaccucmp_path): - self.util.create_dir(output) + create_directory(output) self.util.check_path_valid(output) call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path) if data_format: @@ -84,21 +86,12 @@ class Compare: if left is None or right is None: raise ParseException("invalid input or output") if self.util.check_path_valid(left) and self.util.check_path_valid(right): - try: - left_data = np.load(left) - right_data = np.load(right) - except UnicodeError as e: - self.log.error("%s %s" % ("UnicodeError", str(e))) - self.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e - except IOError: - self.log.error("Failed to load npy %s or %s." % (left, right)) - raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e - + left_data = load_npy(left) + right_data = load_npy(right) # save to txt if save_txt: - self.util.save_npy_to_txt(left_data, left + ".txt") - self.util.save_npy_to_txt(right_data, right + ".txt") + save_npy_to_txt(left_data, left + ".txt") + save_npy_to_txt(right_data, right + ".txt") # compare data (total_cnt, all_close, cos_sim, err_percent) = self.do_compare_data(left_data, right_data, rl, al, diff_count) content = ['Left:', ' ├─ NpyFile: %s' % left] @@ -158,10 +151,9 @@ class Compare: return res def compare_npy(self, file, bench_file, output_path): - if self.util.check_path_valid(file): - data = np.load(file) - if self.util.check_path_valid(bench_file): - bench_data = np.load(bench_file) + if self.util.check_path_valid(file) and self.util.check_path_valid(bench_file): + data = load_npy(file) + bench_data = load_npy(bench_file) shape, dtype = data.shape, data.dtype bench_shape, bench_dtype = bench_data.shape, bench_data.dtype filename = os.path.basename(file) @@ -184,7 +176,7 @@ class Compare: rel_diff_max = np.max(rel_error) compare_result = [[filename, bench_filename, data_mean, bench_data_mean, md5_consistency, abs_diff_max, rel_diff_max]] - self.util.write_csv(compare_result, output_path) + write_csv(compare_result, output_path) def compare_all_file_in_directory(self, my_dump_dir, golden_dump_dir, output_path): if not (self.util.is_subdir_count_equal(my_dump_dir, golden_dump_dir) @@ -231,7 +223,7 @@ class Compare: "Max Abs Error", "Max Relative Error" ]] - self.util.write_csv(title_rows, output_path) + write_csv(title_rows, output_path) my_ordered_subdirs = self.util.get_sorted_subdirectories_names(my_dump_dir) golden_ordered_subdirs = self.util.get_sorted_subdirectories_names(golden_dump_dir) @@ -249,7 +241,9 @@ class Compare: def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path): dump_dir = self.util.path_strip(dump_dir) - for root, _, files in os.walk(dump_dir): + for root, _, files in os.walk(dump_dir, topdown=True): + path_checker = FileChecker(root) + path_checker.common_check() for file in files: file_path = os.path.join(root, file) file_name = os.path.basename(file_path) @@ -260,3 +254,8 @@ class Compare: timestamp = parts[-1] output_path = os.path.join(output_dir, op_name, timestamp) self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path) + path_depth = root.count(os.sep) + if path_depth <= Const.MAX_TRAVERSAL_DEPTH: + yield root, _, files + else: + _[:] = [] diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py index a9a8b2b00..176295ad9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py @@ -38,6 +38,7 @@ class Const: PKL_SUFFIX = ".pkl" DIRECTORY_LENGTH = 4096 FILE_NAME_LENGTH = 255 + MAX_TRAVERSAL_DEPTH = 5 FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' ONE_GB = 1 * 1024 * 1024 * 1024 TEN_GB = 10 * 1024 * 1024 * 1024 diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py index 9a47dc54c..879329ae5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py @@ -23,7 +23,7 @@ from msprobe.pytorch.parse_tool.lib.utils import Util from msprobe.pytorch.parse_tool.lib.compare import Compare from msprobe.pytorch.parse_tool.lib.visualization import Visualization from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException - +from msprobe.core.common.file_check import create_directory class ParseTool: def __init__(self): @@ -33,7 +33,7 @@ class ParseTool: @catch_exception def prepare(self): - self.util.create_dir(Const.DATA_ROOT_DIR) + create_directory(Const.DATA_ROOT_DIR) @catch_exception def do_vector_compare(self, args): @@ -112,8 +112,8 @@ class ParseTool: args = parser.parse_args(argv) self.util.check_path_valid(args.my_dump_path) self.util.check_path_valid(args.golden_dump_path) - self.util.check_path_format(args.my_dump_path, Const.NPY_SUFFIX) - self.util.check_path_format(args.golden_dump_path, Const.NPY_SUFFIX) + self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX) + self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX) compare_data_args = namedtuple('compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count']) compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20) res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count) diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 17a01f20f..8f48f2784 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -31,8 +31,8 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException from msprobe.core.common.file_check import change_mode, check_other_user_writable,\ check_path_executable, check_path_owner_consistent from msprobe.core.common.const import FileCheckConst -from msprobe.core.common.file_check import FileOpen -from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create +from msprobe.core.common.file_check import FileOpen, FileChecker +from msprobe.core.common.utils import check_file_or_directory_path, remove_path from msprobe.pytorch.common.log import logger @@ -57,12 +57,7 @@ except ImportError as err: class Util: def __init__(self): self.ms_accu_cmp = None - logging.basicConfig( - level=Const.LOG_LEVEL, - format="%(asctime)s (%(process)d) -[%(levelname)s]%(message)s", - datefmt="%Y-%m-%d %H:%M:%S" - ) - self.log = logging.getLogger() + self.log = logger self.python = sys.executable @staticmethod @@ -82,6 +77,8 @@ class Util: @staticmethod def get_subdir_count(self, directory): subdir_count = 0 + path_checker = FileChecker(directory) + path_checker.common_check() for _, dirs, _ in os.walk(directory): subdir_count += len(dirs) break @@ -90,8 +87,15 @@ class Util: @staticmethod def get_subfiles_count(self, directory): file_count = 0 - for _, _, files in os.walk(directory): + for root, _, files in os.walk(directory, topdown=True): + path_checker = FileChecker(root) + path_checker.common_check() file_count += len(files) + path_depth = root.count(os.sep) + if path_depth <= Const.MAX_TRAVERSAL_DEPTH: + yield root, _, files + else: + _[:] = [] return file_count @staticmethod @@ -128,21 +132,9 @@ class Util: md5_hash = hashlib.md5(np_bytes) return md5_hash.hexdigest() - @staticmethod - def write_csv(self, data, filepath): - need_change_mode = False - if not os.path.exists(filepath): - need_change_mode = True - with FileOpen(filepath, 'a') as f: - writer = csv.writer(f) - writer.writerows(data) - if need_change_mode: - change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) - @staticmethod def deal_with_dir_or_file_inconsistency(self, output_path): - if os.path.exists(output_path): - os.remove(output_path) + remove_path(output_path) raise ParseException("Inconsistent directory structure or file.") @staticmethod @@ -160,10 +152,17 @@ class Util: @staticmethod def dir_contains_only(self, path, endfix): - for _, _, files in os.walk(path): + for root, _, files in os.walk(path, topdown=True): + path_checker = FileChecker(root) + path_checker.common_check() for file in files: if not file.endswith(endfix): return False + path_depth = root.count(os.sep) + if path_depth <= Const.MAX_TRAVERSAL_DEPTH: + yield root, _, files + else: + _[:] = [] return True @staticmethod @@ -188,7 +187,7 @@ class Util: if not cmd: self.log.error("Commond is None") return -1 - self.log.debug("[RUN CMD]: %s", cmd) + self.log.info("[RUN CMD]: %s", cmd) cmd = cmd.split(" ") complete_process = subprocess.run(cmd, shell=False) return complete_process.returncode @@ -208,7 +207,7 @@ class Util: "Check msaccucmp failed in dir %s. This is not a correct msaccucmp file" % target_file) raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR) result = subprocess.run( - [self.python, target_file, "--help"], stdout=subprocess.PIPE) + [self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False) if result.returncode == 0: self.log.info("Check [%s] success.", target_file) else: @@ -217,41 +216,12 @@ class Util: raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR) return target_file - def create_dir(self, path): - path = self.path_strip(path) - if os.path.exists(path): - return - self.check_path_name(path) - try: - os.makedirs(path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - except OSError as e: - self.log.error("Failed to create %s.", path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e - def gen_npy_info_txt(self, source_data): (shape, dtype, max_data, min_data, mean) = \ self.npy_info(source_data) return \ '[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]' % (shape, dtype, max_data, min_data, mean) - def save_npy_to_txt(self, data, dst_file='', align=0): - if os.path.exists(dst_file): - self.log.info("Dst file %s exists, will not save new one.", dst_file) - return - shape = data.shape - data = data.flatten() - if align == 0: - align = 1 if len(shape) == 0 else shape[-1] - elif data.size % align != 0: - pad_array = np.zeros((align - data.size % align,)) - data = np.append(data, pad_array) - check_path_before_create(dst_file) - try: - np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') - except Exception as e: - self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file) - change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY) - def list_convert_files(self, path, external_pattern=""): return self.list_file_with_pattern( path, Const.OFFLINE_DUMP_CONVERT_PATTERN, external_pattern, self._gen_npu_dump_convert_file_info @@ -278,27 +248,8 @@ class Util: def check_path_valid(self, path): path = self.path_strip(path) - if not path or not os.path.exists(path): - self.log.error("The path %s does not exist." % path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if os.path.islink(path): - self.log.error('The file path {} is a soft link.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ - Const.FILE_NAME_LENGTH: - self.log.error('The file path length exceeds limit.') - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): - self.log.error('The file path {} contains special characters.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if os.path.isfile(path): - file_size = os.path.getsize(path) - if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB: - self.log.error('The file {} size is greater than 1GB.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB: - self.log.error('The file {} size is greater than 10GB.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) + path_checker = FileChecker(path) + path_checker.common_check() return True def check_files_in_path(self, path): @@ -326,17 +277,24 @@ class Util: self.check_path_valid(path) file_list = {} re_pattern = re.compile(pattern) - for dir_path, _, file_names in os.walk(path, followlinks=True): + for dir_path, _, file_names in os.walk(path, topdown=True): + path_checker = FileChecker(dir) + path_checker.common_check() for name in file_names: match = re_pattern.match(name) if not match: continue - if extern_pattern != '' and not re.match(extern_pattern, name): + if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name): continue file_list[name] = gen_info_func(name, match, dir_path) + path_depth = dir_path.count(os.sep) + if path_depth <= Const.MAX_TRAVERSAL_DEPTH: + yield dir_path, _, file_names + else: + _[:] = [] return file_list - def check_path_format(self, path, suffix): + def check_file_path_format(self, path, suffix): if os.path.isfile(path): if not path.endswith(suffix): self.log.error("%s is not a %s file." % (path, suffix)) @@ -348,15 +306,6 @@ class Util: self.log.error("The file path %s is invalid" % path) raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - def check_path_name(self, path): - if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ - Const.FILE_NAME_LENGTH: - self.log.error('The file path length exceeds limit.') - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): - self.log.error('The file path {} contains special characters.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - def check_str_param(self, param): if len(param) > Const.FILE_NAME_LENGTH: self.log.error('The parameter length exceeds limit') diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py index 5e37b58d0..8ca69af85 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py @@ -21,6 +21,7 @@ from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.utils import Util from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.utils import save_npy_to_txt, load_npy class Visualization: @@ -28,12 +29,7 @@ class Visualization: self.util = Util() def print_npy_summary(self, target_file): - try: - np_data = np.load(target_file, allow_pickle=True) - except UnicodeError as e: - self.util.log.error("%s %s" % ("UnicodeError", str(e))) - self.util.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e + np_data = load_npy(target_file, enable_pickle=True) table = self.util.create_table('', ['Index', 'Data']) flatten_data = np_data.flatten() tablesize = 8 @@ -43,18 +39,18 @@ class Visualization: summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file, "TextFile: %s.txt" % target_file] self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file) - self.util.save_npy_to_txt(np_data, target_file + ".txt") + save_npy_to_txt(np_data, target_file + ".txt") def print_npy_data(self, file_name): file_name = self.util.path_strip(file_name) self.util.check_path_valid(file_name) - self.util.check_path_format(file_name, Const.NPY_SUFFIX) + self.util.check_file_path_format(file_name, Const.NPY_SUFFIX) return self.print_npy_summary(file_name) def parse_pkl(self, path, api_name): path = self.util.path_strip(path) self.util.check_path_valid(path) - self.util.check_path_format(path, Const.PKL_SUFFIX) + self.util.check_file_path_format(path, Const.PKL_SUFFIX) self.util.check_str_param(api_name) with FileOpen(path, "r") as pkl_handle: title_printed = False diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index ceec92a63..4b36b350b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -4,19 +4,36 @@ import os from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.file_check import FileOpen from msprobe.core.common.const import Const -from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps +from msprobe.pytorch.hook_module.utils import get_ops +from msprobe.core.grad_probe.constant import level_adp +from msprobe.core.grad_probe.utils import check_numeral_list_ascend class TensorConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) + self.online_run_ut = json_config.get("online_run_ut", False) + self.nfs_path = json_config.get("nfs_path", "") + self.host = json_config.get("host", "") + self.port = json_config.get("port", -1) + self.tls_path = json_config.get("tls_path", "") self.check_config() self._check_file_format() + self._check_tls_path_config() def _check_file_format(self): if self.file_format is not None and self.file_format not in ["npy", "bin"]: raise Exception("file_format is invalid") + def _check_tls_path_config(self): + if self.tls_path: + if not os.path.exists(self.tls_path): + raise Exception("tls_path: %s does not exist" % self.tls_path) + if not os.path.exists(os.path.join(self.tls_path, "client.key")): + raise Exception("tls_path does not contain client.key") + if not os.path.exists(os.path.join(self.tls_path, "client.crt")): + raise Exception("tls_path does not contain client.crt") + class StatisticsConfig(BaseConfig): def __init__(self, json_config): @@ -64,12 +81,19 @@ class FreeBenchmarkCheckConfig(BaseConfig): class RunUTConfig(BaseConfig): - WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) + WrapApi = get_ops() + def __init__(self, json_config): super().__init__(json_config) self.white_list = json_config.get("white_list", Const.DEFAULT_LIST) self.black_list = json_config.get("black_list", Const.DEFAULT_LIST) self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH) + self.is_online = json_config.get("is_online", False) + self.nfs_path = json_config.get("nfs_path", "") + self.host = json_config.get("host", "") + self.port = json_config.get("port", -1) + self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST) + self.tls_path = json_config.get("tls_path", "") self.check_run_ut_config() @classmethod @@ -86,11 +110,44 @@ class RunUTConfig(BaseConfig): def check_error_data_path_config(cls, error_data_path): if not os.path.exists(error_data_path): raise Exception("error_data_path: %s does not exist" % error_data_path) - + + @classmethod + def check_nfs_path_config(cls, nfs_path): + if nfs_path and not os.path.exists(nfs_path): + raise Exception("nfs_path: %s does not exist" % nfs_path) + + @classmethod + def check_tls_path_config(cls, tls_path): + if tls_path: + if not os.path.exists(tls_path): + raise Exception("tls_path: %s does not exist" % tls_path) + if not os.path.exists(os.path.join(tls_path, "server.key")): + raise Exception("tls_path does not contain server.key") + if not os.path.exists(os.path.join(tls_path, "server.crt")): + raise Exception("tls_path does not contain server.crt") + def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list) RunUTConfig.check_error_data_path_config(self.error_data_path) + RunUTConfig.check_nfs_path_config(self.nfs_path) + RunUTConfig.check_tls_path_config(self.tls_path) + + +class GradToolConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + self.grad_level = json_config.get("grad_level", "L1") + self.param_list = json_config.get("param_list", []) + self.bounds = json_config.get("bounds", [-1, 0, 1]) + self._check_config() + + def _check_config(self): + if self.grad_level not in level_adp.keys(): + raise Exception(f"grad_level must be one of {level_adp.keys()}") + if not isinstance(self.param_list, list): + raise Exception(f"param_list must be a list") + check_numeral_list_ascend(self.bounds) def parse_task_config(task, json_config): @@ -110,6 +167,9 @@ def parse_task_config(task, json_config): elif task == Const.RUN_UT: config_dic = json_config.get(Const.RUN_UT, default_dic) return RunUTConfig(config_dic) + elif task == Const.GRAD_PROBE: + config_dic = json_config.get(Const.GRAD_PROBE, default_dic) + return GradToolConfig(config_dic) else: return StatisticsConfig(default_dic) @@ -117,7 +177,7 @@ def parse_task_config(task, json_config): def parse_json_config(json_file_path, task): if not json_file_path: config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - json_file_path = os.path.join(os.path.join(config_dir, "config"), "config.json") + json_file_path = os.path.join(config_dir, "config.json") with FileOpen(json_file_path, 'r') as file: json_config = json.load(file) common_config = CommonConfig(json_config) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index aa22cd0b3..22a17f36f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -1,11 +1,11 @@ import functools import os -from pathlib import Path + +from collections import namedtuple import torch -from packaging import version -from msprobe.core.common.const import Const, FileCheckConst +from msprobe.core.common.const import Const from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException -from msprobe.core.common.file_check import FileChecker, check_path_before_create +from msprobe.core.common.file_check import create_directory from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.core.data_dump.scope import BaseScope @@ -15,9 +15,10 @@ from msprobe.pytorch.hook_module import remove_dropout from msprobe.pytorch.hook_module.api_registry import api_register from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.module_processer import ModuleProcesser +from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook +HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2']) class Service: @@ -31,6 +32,7 @@ class Service: self.first_start = True self.current_rank = None self.dump_iter_dir = None + self.attl = None @staticmethod def forward_backward_dump_end(): @@ -45,6 +47,8 @@ class Service: if not self.switch: return args, kwargs + if self.config.online_run_ut: + return None, None if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -57,6 +61,14 @@ class Service: if not self.switch: return None + + if self.config.online_run_ut: + if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): + return None + api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank) + self.attl_send(api_data) + return None + if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -64,6 +76,9 @@ class Service: return self.data_collector.get_forward_new_output() return output + def forward_hook_torch_version_below_2(api_or_module_name, module, args, output): + return forward_hook(api_or_module_name, module, args, {}, output) + def backward_hook(api_or_module_name, module, grad_input, grad_output): if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name @@ -71,6 +86,14 @@ class Service: if not self.switch: return + + if self.config.online_run_ut: + if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): + return + api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank) + self.attl_send(api_data) + return + if self.data_collector: # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) @@ -79,22 +102,11 @@ class Service: pid = os.getpid() forward_name_template = name + Const.FORWARD backward_name_template = name + Const.BACKWARD - pre_forward_hook = functools.partial(pre_hook, forward_name_template) - forward_hook = functools.partial(forward_hook, forward_name_template) - backward_hook = functools.partial(backward_hook, backward_name_template) - return pre_forward_hook, forward_hook, backward_hook - - def hook_optimizer(self, model): - def optimizer_pre_step_hook(optimizer, args, kwargs): - self.stop() - self.step() - - def optimizer_post_step_hook(optimizer, args, kwargs): - self.start(model) - - - register_optimizer_step_pre_hook(optimizer_pre_step_hook) - register_optimizer_step_post_hook(optimizer_post_step_hook) + pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template) + forward_hook_fn = functools.partial(forward_hook, forward_name_template) + backward_hook_fn = functools.partial(backward_hook, backward_name_template) + forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template) + return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) def step(self): self.current_iter += 1 @@ -106,6 +118,9 @@ class Service: def start(self, model, api_origin=False): self.model = model if self.config.step and self.current_iter > max(self.config.step): + if self.config.online_run_ut: + # send stop signal if online_run_ut + self.attl_stop() self.stop() raise Exception("msprobe: exit after iteration {}".format(max(self.config.step))) if self.config.step and self.current_iter not in self.config.step: @@ -115,6 +130,7 @@ class Service: self.current_rank = get_rank_if_initialized() except DistributedNotInitializedError: self.current_rank = None + self.attl_init() if self.config.rank and self.current_rank not in self.config.rank: return @@ -124,7 +140,7 @@ class Service: api_register.api_modularity() self.switch = True logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") - if self.config.level != "L2": + if self.config.level != "L2" and not self.config.online_run_ut: self.create_dirs() logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") @@ -136,22 +152,19 @@ class Service: if self.config.rank and self.current_rank not in self.config.rank: return self.switch = False + if self.config.online_run_ut: + return self.data_collector.write_json() def create_dirs(self): - check_path_before_create(self.config.dump_path) - if not os.path.exists(self.config.dump_path): - Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True) - file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR) - file_check.common_check() + create_directory(self.config.dump_path) self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") cur_rank = self.current_rank if self.current_rank is not None else '' dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - if not os.path.exists(dump_dir): - Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True) + create_directory(dump_dir) if self.config.task in self.data_collector.tasks_need_tensor_data: dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True) + create_directory(dump_data_dir) else: dump_data_dir = None @@ -174,22 +187,60 @@ class Service: prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \ module.__class__.__name__ + Const.SEP - pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) - module.register_forward_hook(forward_hook, with_kwargs=True) + pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 \ + = self.build_hook(BaseScope.Module_Type_Module, prefix) + if torch_version_above_or_equal_2: + module.register_forward_hook(forward_hook, with_kwargs=True) + else: + module.register_full_backward_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + module.register_forward_hook(forward_hook_torch_version_below_2) module.register_full_backward_hook(backward_hook) module.register_forward_pre_hook( self.module_processor.node_hook(prefix + Const.FORWARD, Const.START)) module.register_forward_hook( self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - module.register_full_backward_pre_hook( - self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - module.register_full_backward_hook( - self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + if torch_version_above_or_equal_2: + module.register_full_backward_pre_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START)) + module.register_full_backward_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) if self.config.level in ["mix", "L1", "L2"]: api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_modularity() if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task: - remove_dropout() \ No newline at end of file + remove_dropout() + + def attl_init(self): + if self.config.online_run_ut: + from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL + attl_config = ATTLConfig(is_benchmark_device=False, + connect_ip=self.config.host, + connect_port=self.config.port, + nfs_path=self.config.nfs_path, + tls_path=self.config.tls_path) + need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank + self.attl = ATTL('npu', attl_config, need_dump=need_dump) + if self.config.nfs_path: + self.attl.upload("start") + + def attl_send(self, api_data): + logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}") + api_type, _, _ = api_data.name.split(Const.SEP) + if api_type in [Const.DISTRIBUTED]: + logger.info(f"api {api_data.name} is not supported, skip") + return + if self.config.nfs_path: + self.attl.upload(api_data) + else: + self.attl.send(api_data) + + def attl_stop(self): + if self.config.nfs_path: + self.attl.upload("end") + elif self.attl.socket_manager is not None: + logger.info(f"pid: {os.getpid()} finished, start send STOP signal.") + self.attl.socket_manager.send_stop_signal() -- Gitee