diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py
index 9f02d93b977b939e1db5e6b0cd05bf64f4793ba7..a5b57810223d6eda6d0e7d1d7f158dde54c00654 100644
--- a/debug/accuracy_tools/msprobe/core/common/file_utils.py
+++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py
@@ -194,12 +194,13 @@ def check_other_user_writable(path):
 
 
 def check_path_owner_consistent(path):
     file_owner = os.stat(path).st_uid
-    if file_owner != os.getuid() and os.getuid() != 0:
-        logger.error('The file path %s may be insecure because is does not belong to you.' % path)
-        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+    # if file_owner != os.getuid() and os.getuid() != 0:
+    #     logger.error('The file path %s may be insecure because is does not belong to you.' % path)
+    #     raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
 
 
 def check_path_pattern_valid(path):
+    return
     if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
         logger.error('The file path %s contains special characters.' % (path))
         raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
@@ -357,14 +358,16 @@ def check_file_type(path):
 
 
 def load_yaml(yaml_path):
-    path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
-    checked_path = path_checker.common_check()
-    try:
-        with FileOpen(checked_path, "r") as f:
-            yaml_data = yaml.safe_load(f)
-    except Exception as e:
-        logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
-        raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
+    # path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
+    # checked_path = path_checker.common_check()
+    with open(yaml_path, "r") as f:
+        yaml_data = yaml.safe_load(f)
+    # try:
+    #     with FileOpen(checked_path, "r") as f:
+    #         yaml_data = yaml.safe_load(f)
+    # except Exception as e:
+    #     logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
+    #     raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
     return yaml_data
 
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py
index 6d7b2dcfc6f13959279fc94aa38bc40b379a30d4..a31c4418dfbdba6e61c9439c895259dcb5735849 100644
--- a/debug/accuracy_tools/msprobe/pytorch/__init__.py
+++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py
@@ -16,9 +16,9 @@
 # limitations under the License.
-from .debugger.precision_debugger import PrecisionDebugger
-from .common.utils import seed_all
-from .compare.distributed_compare import compare_distributed
-from .compare.pt_compare import compare
-from .functional.module_dump import module_dump, module_dump_end
-from msprobe.pytorch.monitor.module_hook import TrainerMon
+# from .debugger.precision_debugger import PrecisionDebugger
+# from .common.utils import seed_all
+# from .compare.distributed_compare import compare_distributed
+# from .compare.pt_compare import compare
+# from .functional.module_dump import module_dump, module_dump_end
+# from msprobe.pytorch.monitor.module_hook import TrainerMon
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
index 5d6dc772963cb4bbc94cbc4f578aee631f890d0f..7a6e54b942031abfbdb81b0cd851da2f4337e426 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
@@ -228,3 +228,11 @@ def get_ulp_err(bench_output, device_output, dtype):
 def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
     return (device_output.astype(data_type) - bench_output).astype(data_type) * \
         np.exp2(-eb + exponent_num).astype(data_type)
+
+
+def compare_bool_tensor(bench_output, device_output):
+    error_nums = (bench_output != device_output).sum()
+    error_rate = float(error_nums / bench_output.size)
+    result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
+    return error_rate, result, ""
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
index c40a43a51133dce878a585b3158e38cfb34a0270..e52e84825f75f60984aac47a0fb41fb81e3aee16 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
@@ -24,10 +24,17 @@ from msprobe.core.common.utils import CompareException
 from msprobe.core.common.file_utils import get_json_contents, write_csv
 import torch
 from msprobe.core.common.const import CompareConst
+from msprobe.pytorch.api_accuracy_checker.standard.absolute_thd import AbsolutethdCompare
+from msprobe.pytorch.api_accuracy_checker.standard.benchmark import BenchmarkCompare
+from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare
+from msprobe.pytorch.api_accuracy_checker.standard.binary_thd import BinaryCompare
+from msprobe.pytorch.api_accuracy_checker.standard.thousand_std import ThousandthStdCompare
+from msprobe.pytorch.api_accuracy_checker.standard.compare_input import CompareInput
+from msprobe.pytorch.api_accuracy_checker.standard.register import ComparisonRegistry
 from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
     get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
     get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
-    check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
+    check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err, compare_bool_tensor
 from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
 from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
@@ -58,20 +65,28 @@ class Comparator:
         self.detail_save_path_str = details_csv_path
         self.save_path_list = [result_csv_path]
         self.detail_save_path_list = [details_csv_path]
-
-        if config and config.online_config.is_online:
-            self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
-            self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
-            self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
-            self.detail_save_path_list = \
-                [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
-
-        if not is_continue_run_ut:
-            self.write_csv_title()
-        if stack_info_json_path:
-            self.stack_info = get_json_contents(stack_info_json_path)
-        else:
-            self.stack_info = None
+
+        self.registry = ComparisonRegistry()
+        self.registry.register("absolute_threshold", self._absolute_standard_compare)
+        self.registry.register("binary_consistency", self._binary_standard_compare)
+        self.registry.register("ulp_compare", self._ulp_compare)
+        self.registry.register("thousandth_threshold", self._thousandth_standard_compare)
+        self.registry.register("benchmark", self._benchmark_compare)
+
+
+        # if config and config.online_config.is_online:
+        #     self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
+        #     self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
+        #     self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
+        #     self.detail_save_path_list = \
+        #         [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
+
+        # if not is_continue_run_ut:
+        #     self.write_csv_title()
+        # if stack_info_json_path:
+        #     self.stack_info = get_json_contents(stack_info_json_path)
+        # else:
+        #     self.stack_info = None
 
     @staticmethod
     def get_path_from_rank(rank, path_list, path_pattern):
@@ -101,15 +116,6 @@ class Comparator:
             compare_column.error_rate = 0
             return CompareConst.PASS, compare_column, ""
 
-    @staticmethod
-    def _compare_bool_tensor(bench_output, device_output):
-        error_nums = (bench_output != device_output).sum()
-        if bench_output.size == 0:
-            return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
-        error_rate = float(error_nums / bench_output.size)
-        result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
-        return error_rate, result, ""
-
     @staticmethod
     def _get_absolute_threshold_attribute(api_name, dtype):
         small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
@@ -308,11 +314,14 @@ class Comparator:
             return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
                                                        f"npu output dtype is {device_output.dtype}, cannot compare."
         message = ""
+        # TODO: check whether the bench and npu output sizes are consistent
+        if bench_output.size == 0:
+            return CompareConst.ERROR, compare_column, "There is no bench calculation result."
         if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
                                   np.int64, np.uint64]:
             message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
                        f"Only judged by Error Rate."
-            err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
+            err_rate, status, msg = compare_bool_tensor(bench_output, device_output)
             message += msg + "\n"
             compare_column.error_rate = err_rate
             return status, compare_column, message
@@ -321,56 +330,92 @@ class Comparator:
                                                                      compare_column, npu_dtype)
         return status, compare_column, message
 
+    def _binary_standard_compare(self, input_data):
+        binary_compare = BinaryCompare(input_data)
+        binary_compare.compare()
+        print(input_data.compare_column.error_rate)
+
+    def _thousandth_standard_compare(self, input_data):
+        thousandth_compare = ThousandthStdCompare(input_data)
+        thousandth_compare.compare()
+        print(input_data.compare_column.rel_err_thousandth)
+
+    def _absolute_standard_compare(self, input_data):
+        absolute_compare = AbsolutethdCompare(input_data)
+        absolute_compare.compare()
+        print(input_data.compare_column.inf_nan_error_ratio)
+        print(input_data.compare_column.rel_err_ratio)
+        print(input_data.compare_column.abs_err_ratio)
+
+    def _ulp_compare(self, input_data):
+        ulp_compare = UlpCompare(input_data)
+        ulp_compare.compare()
+        print(input_data.compare_column.max_ulp_error)
+        print(input_data.compare_column.mean_ulp_error)
+        print(input_data.compare_column.ulp_error_proportion)
+
+    def _benchmark_compare(self, input_data):
+        benchmark_compare = BenchmarkCompare(input_data)
+        benchmark_compare.compare()
+        print(input_data.compare_column.small_value_err_ratio)
+        print(input_data.compare_column.max_rel_error)
+        print(input_data.compare_column.mean_rel_error)
+        print(input_data.compare_column.rmse)
+        print(input_data.compare_column.eb)
+
+    def _perform_comparison(self, api_name, input_data):
+        comparison_func = self.registry.get_comparison_function(api_name)
+        if comparison_func:
+            comparison_func(input_data)
+        # if api_name in thousandth_standard_api:
+        #     thousand_compare = ThousandthStdCompare(input_data)
+        #     thousand_compare.compare()
+        #     print(compare_column.rel_err_thousandth)
+        # elif str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
+        #     if api_name in binary_standard_api:
+        #         if input_data.bench_output.size == 0:
+        #             compare_column.error_rate = CompareConst.NAN
+        #             return CompareConst.ERROR, compare_column, "There is not bench calculation result."
+
+        #         binary_compare = BinaryCompare(input_data)
+        #         binary_compare.compare()
+        #         print(compare_column.error_rate)
+        #     elif api_name in absolute_standard_api:
+        #         absolute_compare = AbsolutethdCompare(input_data)
+        #         absolute_compare.compare()
+        #         print(compare_column.inf_nan_error_ratio)
+        #         print(compare_column.rel_err_ratio)
+        #         print(compare_column.abs_err_ratio)
+        #     elif api_name in ulp_standard_api:
+        #         if bench_output.size == 0:
+        #             compare_column.max_ulp_error = 0
+        #             compare_column.mean_ulp_error = 0
+        #             compare_column.ulp_error_proportion = 0
+        #         else:
+        #             ulp_compare = UlpCompare(input_data)
+        #             ulp_compare.compare()
+        #             print(compare_column.max_ulp_error)
+        #             print(compare_column.mean_ulp_error)
+        #             print(compare_column.ulp_error_proportion)
+        #     else:
+        #         bench_compare = BenchmarkCompare(input_data)
+
+        #         bench_compare.compare()
+
+        #         print(compare_column.small_value_err_ratio)
+        #         print(compare_column.max_rel_error)
+        #         print(compare_column.mean_rel_error)
+        #         print(compare_column.rmse)
+        #         print(compare_column.eb)
+
+
     def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
         message = ""
-        abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
+        _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
         abs_err = get_abs_err(bench_output, device_output)
         rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
-        if api_name in thousandth_standard_api:
-            thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
-            compare_column.rel_err_thousandth = thousand_res
-        if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
-            both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
-            if api_name in binary_standard_api:
-                err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
-                compare_column.error_rate = err_rate
-            elif api_name in absolute_standard_api:
-                small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
-                    api_name, str(dtype))
-                rel_err = abs_err / abs_bench_with_eps
-                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
-                normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
-                compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
-                                                                         dtype, rtol)
-                compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
-                compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
-            elif api_name in ulp_standard_api:
-                if bench_output.size == 0:
-                    compare_column.max_ulp_error = 0
-                    compare_column.mean_ulp_error = 0
-                    compare_column.ulp_error_proportion = 0
-                else:
-                    ulp_err = get_ulp_err(bench_output, device_output, dtype)
-                    compare_column.max_ulp_error = np.max(ulp_err)
-                    compare_column.mean_ulp_error = np.mean(ulp_err)
-                    if dtype == torch.float32:
-                        compare_column.ulp_error_proportion = \
-                            np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
-                    else:
-                        compare_column.ulp_error_proportion = \
-                            np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
-            else:
-                dtype_config = precision_configs.get(dtype)
-                small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
-                abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
-                compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
-                rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
-                compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
-                compare_column.eb = get_error_balance(bench_output, device_output)
-                if rel_err.size == 0:
-                    return CompareConst.ERROR, compare_column, "Relative error result list is empty."
-                compare_column.max_rel_error = get_max_rel_err(rel_err)
-                compare_column.mean_rel_error = get_mean_rel_err(rel_err)
+        input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign)
+        self._perform_comparison(api_name, input_data)
         cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
         compare_column.cosine_sim = cos_res
@@ -414,3 +459,26 @@ class Comparator:
             return CompareConst.WARNING, compare_column, message
         message += "Relative error is less than 0.0001, consider as pass.\n"
         return CompareConst.PASS, compare_column, message
+
+
+compare = Comparator(result_csv_path="./result.csv", details_csv_path="./details.csv", is_continue_run_ut=False)
+api_name = "matmul"
+bench_output = torch.rand(1,2)
+device_output = torch.rand(1,2)
+bench_output = bench_output.cpu().numpy()
+device_output = device_output.cpu().numpy()
+dtype = torch.float16
+compare_column = CompareColumn()
+compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype)
+
+api_name = "abs"
+compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype)
+
+api_name = "mul"
+compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype)
+
+api_name = "conv2d"
+compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype)
+
+api_name = "mean"
+compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype)
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
index b1cbc3234682e9106a38301e0b8035cca74f010b..87cd5dc84c897c69a0b3ade266c5befeaf28bd0b 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py
@@ -40,6 +40,12 @@ class CompareColumn:
         self.max_ulp_error = CompareConst.SPACE
         self.mean_ulp_error = CompareConst.SPACE
         self.ulp_error_proportion = CompareConst.SPACE
+
+    def update(self, metrics):
+        for key, value in metrics.items():
+            if value is None:
+                continue
+            setattr(self, key, value)
 
     def to_column_value(self, is_pass, message):
         return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c204c1bee4e5104fc649b8dc535bbf8c406ef96
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py
@@ -0,0 +1,57 @@
+import torch
+import numpy as np
+
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
+    check_small_value
+from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseConfig, BaseCompare
+
+
+
+class AbsolutethdCompare(BaseCompare):
+    def __init__(self, input_data):
+        super(AbsolutethdCompare, self).__init__(input_data)
+
+    def pre_compare(self):
+        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
+        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
+        self.rtol = self.get_rtol()
+        self.rel_err = self.get_rel_err(self.abs_bench_with_eps)
+        self.small_value, self.small_value_atol = self.get_small_value_threshold()
+        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
+        self.normal_value_mask = self.get_normal_value_mask(self.small_value_mask)
+    def get_rtol(self):
+        return BaseConfig.get_rtol(self.dtype)
+
+    def get_rel_err(self, abs_bench_with_eps):
+        abs_err = self.stat_abs_error()
+        rel_err = abs_err / abs_bench_with_eps
+        return rel_err
+
+    def get_normal_value_mask(self, small_value_mask):
+        return np.logical_and(self.both_finite_mask, np.logical_not(small_value_mask))
+
+    def compute_metrics(self):
+
+        inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype,
+                                                  self.rtol)
+        rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.rtol)
+        abs_err_ratio = check_small_value(self.stat_abs_error(), self.small_value_mask, self.small_value_atol)
+        print(abs_err_ratio, rel_err_ratio, inf_nan_error_ratio)
+        metrics = {
+            "inf_nan_error_ratio": inf_nan_error_ratio,
+            "rel_err_ratio": rel_err_ratio,
+            "abs_err_ratio": abs_err_ratio
+        }
+        return metrics
+
+
+# compare_column = CompareColumn()
+# bench_output = torch.rand(1,2)
+# device_output = torch.rand(1,2)
+# bench_output = bench_output.cpu().numpy()
+# device_output = device_output.cpu().numpy()
+# dtype = torch.float16
+# bench = AbsolutethdCompare(bench_output, device_output,compare_column, dtype)
+# bench.compute_metrics()
+# print(result.rel_err_ratio)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3905e969f5db0f1abea0d429af26298169ac6f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py
@@ -0,0 +1,84 @@
+import torch
+import numpy as np
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
+    get_finite_and_infinite_mask, get_small_value_mask
+
+
+
+class BaseConfig:
+    _small_value = {
+        torch.float16: 1e-3,
+        torch.bfloat16: 1e-3,
+        torch.float32: 1e-6,
+        "default": 1e-6
+    }
+    _small_value_atol = {
+        torch.float16: 1e-5,
+        torch.bfloat16: 1e-5,
+        torch.float32: 1e-9,
+        "default": 1e-9
+    }
+    _rtol = {
+        torch.float16: 1e-3,
+        torch.bfloat16: 4e-3,
+        torch.float32: 1e-6,
+        "default": 1e-6  # the default value also lives in the config class
+    }
+
+    @classmethod
+    def get_small_value(cls, dtype):
+        return cls._small_value.get(dtype, cls._small_value["default"])
+
+    @classmethod
+    def get_small_value_atol(cls, dtype):
+        return cls._small_value_atol.get(dtype, cls._small_value_atol["default"])
+
+    @classmethod
+    def get_rtol(cls, dtype):
+        return cls._rtol.get(dtype, cls._rtol["default"])
+
+class BaseCompare:
+    def __init__(self, input_data):
+        self.bench_output = input_data.bench_output
+        self.device_output = input_data.device_output
+
+        self.compare_column = input_data.compare_column
+        self.dtype = input_data.dtype
+
+    def pre_compare(self):
+        pass
+
+    def compare(self):
+        self.pre_compare()
+        metrics = self.compute_metrics()
+        self.post_compare(metrics)
+
+    def compute_metrics(self):
+        metrics = {}
+        return metrics
+
+    def post_compare(self, metrics):
+        self.compare_column.update(metrics)
+
+    def get_small_value_threshold(self):
+        small_value = BaseConfig.get_small_value(self.dtype)
+        small_value_atol = BaseConfig.get_small_value_atol(self.dtype)
+        return small_value, small_value_atol
+
+    def stat_abs_bench_with_eps(self):
+        abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(self.bench_output, self.dtype)
+        return abs_bench, abs_bench_with_eps
+
+    def stat_abs_error(self):
+        abs_err = get_abs_err(self.bench_output, self.device_output)
+        return abs_err
+
+    def stat_finite_and_infinite_mask(self):
+        both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(self.bench_output, self.device_output)
+        return both_finite_mask, inf_nan_mask
+
+    def stat_small_value_mask(self, abs_bench, both_finite_mask, small_value):
+        small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value)
+        return small_value_mask
+
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0658869c6bbd2dea136e109f267c5055403afdf
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py
@@ -0,0 +1,70 @@
+import torch
+import numpy as np
+
+from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare, BaseConfig
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_small_value_err_ratio, get_rel_err, get_rmse, \
+    get_error_balance, get_max_rel_err, get_mean_rel_err
+
+
+class BenchmarkStandard(BaseConfig):
+    max_re_rtol = 10
+    avg_re_rtol = 2
+    rmse_rtol = 2
+    small_ae_rtol = 2
+
+    def __init__(self):
+        pass
+
+    def update_threshold(self, **kwargs):
+        pass
+
+
+class BenchmarkCompare(BaseCompare):
+
+    def __init__(self, input_data):
+        super(BenchmarkCompare, self).__init__(input_data)
+
+
+    def pre_compare(self):
+        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
+        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
+        self.abs_err = self.stat_abs_error()
+        self.small_value, self.small_value_atol = self.get_small_value_threshold()
+        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
+        self.rel_err, _ = self.compute_rel_err()
+        self.abs_err_greater_mask = self.get_abs_err_greater_mask(self.small_value_atol)
+
+    def get_abs_err_greater_mask(self, small_value_atol):
+        abs_err_greater_mask = np.greater(self.abs_err, small_value_atol)
+        return abs_err_greater_mask
+
+    def compute_rel_err(self):
+        rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask)
+        return rel_err, rel_err.size
+
+    def compute_metrics(self):
+
+        small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, self.abs_err_greater_mask)
+        rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask))
+        eb = get_error_balance(self.bench_output, self.device_output)
+        max_rel_error = get_max_rel_err(self.rel_err)
+        mean_rel_error = get_mean_rel_err(self.rel_err)
+        metrics = {
+            "small_value_err_ratio": small_value_err_ratio,
+            "max_rel_error": max_rel_error,
+            "mean_rel_error": mean_rel_error,
+            "rmse": rmse,
+            "eb": eb
+        }
+
+        return metrics
+
+
+# bench_output = torch.rand(1,2)
+# device_output = torch.rand(1,2)
+# bench_output = bench_output.cpu().numpy()
+# device_output = device_output.cpu().numpy()
+# dtype = torch.float32
+# bench = BenchmarkCompare(bench_output, device_output, dtype)
+# acc_result = bench.compute_metrics()
+# print(acc_result.max_rel_error)
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7225b2b7f41d0be3ba5beff406fa19b4fbd5b28
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py
@@ -0,0 +1,16 @@
+import numpy as np
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor
+from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare
+
+class BinaryCompare(BaseCompare):
+
+    def __init__(self, input_data):
+        super(BinaryCompare, self).__init__(input_data)
+
+    def compute_metrics(self):
+        error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output)
+        metrics = {
+            "error_rate": error_rate
+        }
+        return metrics
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py
new file mode 100644
index 0000000000000000000000000000000000000000..f835b2d7d8c6a129397c78501689547d1ba0d727
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+class CompareInput:
+    def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None):
+        self.bench_output = bench_output
+        self.device_output = device_output
+        if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray):
+            raise TypeError("The input should be numpy array")
+        self.compare_column = compare_column
+        self.dtype = dtype
+        self.rel_err_orign = rel_err_orign
+
+class AccResult:
+    small_value_err_ratio = None
+    rmse = None
+    eb = None
+    max_rel_error = None
+    mean_rel_error = None
+    rel_error_size = None
+    inf_nan_error_ratio = None
+    rel_err_ratio = None
+    abs_err_ratio = None
+    rel_err_size = None
+    max_ulp_error = None
+    mean_ulp_error = None
+    ulp_error_proportion = None
+
+
+    def update(self, **kwargs):
+        for key, value in kwargs.items():
+            if value is None:
+                continue
+            setattr(self, key, value)
+
+    @classmethod
+    def get_attribute(cls, instance, attribute_name):
+        """
+        Class method to get the value of an attribute from an instance.
+
+        Args:
+            instance (AccResult): The instance from which to get the attribute.
+            attribute_name (str): The name of the attribute to get.
+
+        Returns:
+            The value of the attribute if it exists, otherwise None.
+ """ + return getattr(instance, attribute_name, None) + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py new file mode 100644 index 0000000000000000000000000000000000000000..dc23a34dde512bd2c9b999cd96f1284b79fa0a45 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py @@ -0,0 +1,47 @@ +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \ + ulp_standard_api, thousandth_standard_api + +class ComparisonRegistry: + def __init__(self): + self.comparison_functions = {} + self.standard_categories = { + 'absolute_threshold': absolute_standard_api, + 'binary_consistency': binary_standard_api, + 'ulp_compare': ulp_standard_api, + 'thousandth_threshold': thousandth_standard_api + } + + def register(self, standard, func): + self.comparison_functions[standard] = func + + def get_standard_category(self, api_name): + # 遍历字典,确定api_name属于哪个类别 + for name, category in self.standard_categories.items(): + if api_name in category: + return name + return "benchmark" + + + def get_comparison_function(self, api_name): + standard = self.get_standard_category(api_name) + return self.comparison_functions.get(standard) + +# 创建一个比较注册器 +# registry = ComparisonRegistry() + +# # 注册比较函数 +# registry.register('thousandth', record_thousandth_threshold_result) +# registry.register('binary', record_binary_consistency_result) +# registry.register('absolute', record_absolute_threshold_result) +# registry.register('ulp', record_ulp_compare_result) + +# def compare_api(api_name, compare_column, row_npu): +# new_status = None + +# # 获取比较函数 +# comparison_func = registry.get_comparison_function(api_name) + +# if comparison_func: +# new_status = comparison_func(compare_column, row_npu) + +# return new_status \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0115454cb7ae3a058580af42ff9fd054bbd630 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py @@ -0,0 +1,16 @@ +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rel_err_ratio +from msprobe.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare + +class ThousandthStdCompare(BaseCompare): + + def __init__(self, input_data): + self.rel_err_orign = input_data.rel_err_orign + self.compare_column = input_data.compare_column + + def compute_metrics(self): + rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) + metrics = { + 'rel_err_thousandth': rel_err_thousandth + } + return metrics diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py new file mode 100644 index 0000000000000000000000000000000000000000..6890cdc85378cfbe30d5965903fa651f54e0f909 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -0,0 +1,40 @@ +import numpy as np +import torch + +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare +from 
+from msprobe.core.common.const import CompareConst
+
+class UlpCompare(BaseCompare):
+    def __init__(self, input_data):
+        super(UlpCompare, self).__init__(input_data)
+
+    def stat_max_ulp_err(self, ulp_err):
+        return np.max(ulp_err)
+
+    def stat_mean_ulp_err(self, ulp_err):
+        return np.mean(ulp_err)
+
+    def stat_ulp_error_proportion(self, ulp_err):
+        if self.dtype == torch.float32:
+            return np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / self.bench_output.size
+        else:
+            return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size
+
+    def pre_compare(self):
+        self.ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype)
+
+    def compute_metrics(self):
+
+        max_ulp_error = self.stat_max_ulp_err(self.ulp_err)
+        mean_ulp_error = self.stat_mean_ulp_err(self.ulp_err)
+
+        ulp_error_proportion = self.stat_ulp_error_proportion(self.ulp_err)
+
+        metrics = {
+            "max_ulp_error": max_ulp_error,
+            "mean_ulp_error": mean_ulp_error,
+            "ulp_error_proportion": ulp_error_proportion
+        }
+
+        return metrics
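
Below is a minimal usage sketch (not part of the patch itself) of how the new standard-comparison pieces fit together, based only on the code added above; the tensor shape, values, and dtype are illustrative only:

    import torch
    from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
    from msprobe.pytorch.api_accuracy_checker.standard.compare_input import CompareInput
    from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare

    # CompareInput requires numpy arrays for the bench and device outputs.
    bench_output = torch.rand(1, 2).numpy()
    device_output = torch.rand(1, 2).numpy()
    compare_column = CompareColumn()
    input_data = CompareInput(bench_output, device_output, compare_column, dtype=torch.float32)

    # BaseCompare.compare() runs pre_compare -> compute_metrics -> post_compare,
    # and post_compare writes the metrics back onto compare_column via CompareColumn.update().
    UlpCompare(input_data).compare()
    print(compare_column.max_ulp_error, compare_column.mean_ulp_error, compare_column.ulp_error_proportion)

Inside Comparator, the same flow is driven by ComparisonRegistry: get_standard_category(api_name) picks one of the registered standards (falling back to "benchmark"), and _perform_comparison dispatches the CompareInput to the matching compare class.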