diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 931dcae9f246d1ea264a915851ef0d793eb87d83..5501538c50c7ffd5c2972e575e7cd3c7e9e65592 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -1,86 +1,95 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json -from api_accuracy_checker.common.utils import print_error_log, CompareException -from api_accuracy_checker.hook_module.register_hook import initialize_hook -from api_accuracy_checker.common.config import msCheckerConfig - - -def set_dump_switch(switch): - if switch not in ["ON", "OFF"]: - print_error_log("Please set switch with 'ON' or 'OFF'.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if switch == "ON": - initialize_hook(pretest_hook) - initialize_output_json() - DumpUtil.set_dump_switch(switch) - - -class DumpUtil(object): - dump_switch = None - call_num = 0 - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" - - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - DumpUtil.call_num += 1 - - -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' - - -def pretest_info_dump(name, out_feat, module, phase): - if not DumpUtil.get_dump_switch(): - return - if phase == DumpConst.forward: - api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) - elif phase == DumpConst.backward: - api_info = BackwardAPIInfo(name, out_feat) - else: - msg = "Unexpected training phase {}.".format(phase) - print_error_log(msg) - raise NotImplementedError(msg) - - write_api_info_json(api_info) - - -def pretest_hook(name, phase): - def pretest_info_dump_hook(module, in_feat, out_feat): - pretest_info_dump(name, out_feat, module, phase) - if hasattr(module, "input_args"): - del module.input_args - if hasattr(module, "input_kwargs"): - del module.input_kwargs - return pretest_info_dump_hook +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import torch.distributed as dist + +from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo +from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json +from api_accuracy_checker.common.utils import print_error_log, CompareException +from api_accuracy_checker.hook_module.register_hook import initialize_hook +from api_accuracy_checker.common.config import msCheckerConfig + + +def set_dump_switch(switch): + if switch not in ["ON", "OFF"]: + print_error_log("Please set switch with 'ON' or 'OFF'.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + if switch == "ON": + initialize_hook(pretest_hook) + initialize_output_json() + DumpUtil.set_dump_switch(switch) + + +class DumpUtil(object): + dump_switch = None + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") + DumpUtil.call_num += 1 + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' + + +def pretest_info_dump(name, out_feat, module, phase): + if not DumpUtil.get_dump_switch(): + return + if phase == DumpConst.forward: + if "Distributed" in name: + if module.input_kwargs.get("op"): + module.input_kwargs["op"] = module.input_kwargs["op"].name + if module.input_kwargs.get("group"): + if isinstance(module.input_kwargs["group"], dist.distributed_c10d.ProcessGroup): + module.input_kwargs["group"] = module.input_kwargs["group"].size() + else: + module.input_kwargs["group"] = dist.distributed_c10d.get_world_size() + api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) + elif phase == DumpConst.backward: + api_info = BackwardAPIInfo(name, out_feat) + else: + msg = "Unexpected training phase {}.".format(phase) + print_error_log(msg) + raise NotImplementedError(msg) + + write_api_info_json(api_info) + + +def pretest_hook(name, phase): + def pretest_info_dump_hook(module, in_feat, out_feat): + pretest_info_dump(name, out_feat, module, phase) + if hasattr(module, "input_args"): + del module.input_args + if hasattr(module, "input_kwargs"): + del module.input_kwargs + return pretest_info_dump_hook diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py index b355e029b6b74e2accc9241b42deebe31cb8e5ca..4f554bc72a573bf4f5feb0728001862449fef45d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py @@ -1,37 +1,62 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. 
-# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import torch - -from api_accuracy_checker.hook_module import wrap_torch, wrap_functional, wrap_tensor - - -def initialize_hook(hook): - wrap_tensor.wrap_tensor_ops_and_bind(hook) - for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): - setattr(torch.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) - - wrap_torch.wrap_torch_ops_and_bind(hook) - for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith("wrap_"): - setattr(torch, attr_name[5:], getattr(wrap_torch.HOOKTorchOP, attr_name)) - - wrap_functional.wrap_functional_ops_and_bind(hook) - for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - setattr(torch.nn.functional, attr_name[5:], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import torch +import torch.distributed as dist + +from api_accuracy_checker.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_distributed +from api_accuracy_checker.common.utils import torch_without_guard_version + +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + from . 
import wrap_npu_custom + + +def initialize_hook(hook): + wrap_tensor.wrap_tensor_ops_and_bind(hook) + for attr_name in dir(wrap_tensor.HOOKTensor): + if attr_name.startswith("wrap_"): + setattr(torch.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) + + wrap_torch.wrap_torch_ops_and_bind(hook) + for attr_name in dir(wrap_torch.HOOKTorchOP): + if attr_name.startswith("wrap_"): + setattr(torch, attr_name[5:], getattr(wrap_torch.HOOKTorchOP, attr_name)) + + wrap_functional.wrap_functional_ops_and_bind(hook) + for attr_name in dir(wrap_functional.HOOKFunctionalOP): + if attr_name.startswith("wrap_"): + setattr(torch.nn.functional, attr_name[5:], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) + + wrap_distributed.wrap_distributed_ops_and_bind(hook) + for attr_name in dir(wrap_distributed.HOOKDistributedOP): + if attr_name.startswith("wrap_"): + setattr(dist, attr_name[5:], getattr(wrap_distributed.HOOKDistributedOP, attr_name)) + setattr(dist.distributed_c10d, attr_name[5:], getattr(wrap_distributed.HOOKDistributedOP, attr_name)) + if not is_gpu and not torch_without_guard_version: + setattr(torch_npu.distributed, attr_name[5:], getattr(wrap_distributed.HOOKDistributedOP, attr_name)) + setattr(torch_npu.distributed.distributed_c10d, attr_name[5:], + getattr(wrap_distributed.HOOKDistributedOP, attr_name)) + + if not is_gpu: + wrap_npu_custom.wrap_npu_ops_and_bind(hook) + for attr_name in dir(wrap_npu_custom.HOOKNpuOP): + if attr_name.startswith("wrap_"): + setattr(torch_npu, attr_name[5:], getattr(wrap_npu_custom.HOOKNpuOP, attr_name)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml index c7ed0a1f81cf7b6b2e17ce0e6c37965567f5e42a..ae5eb4b2ec976a3de8a4f985d5856ece2a909698 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml @@ -1,999 +1,1066 @@ -# Copyright (c) 2023 Huawei Technologies Co., Ltd -# All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# List of ops that register hooks - -functional: - - conv1d - - conv2d - - conv3d - - conv_transpose1d - - conv_transpose2d - - conv_transpose3d - - conv_tbc - - avg_pool1d - - avg_pool2d - - avg_pool3d - - fractional_max_pool2d_with_indices - - fractional_max_pool2d - - fractional_max_pool3d_with_indices - - fractional_max_pool3d - - max_pool1d_with_indices - - max_pool1d - - max_pool2d_with_indices - - max_pool2d - - max_pool3d_with_indices - - max_pool3d - - max_unpool1d - - max_unpool2d - - max_unpool3d - - lp_pool2d - - lp_pool1d - - adaptive_max_pool1d_with_indices - - adaptive_max_pool1d - - adaptive_max_pool2d_with_indices - - adaptive_max_pool2d - - adaptive_max_pool3d_with_indices - - adaptive_max_pool3d - - adaptive_avg_pool1d - - adaptive_avg_pool2d - - adaptive_avg_pool3d - - dropout - - alpha_dropout - - dropout2d - - dropout3d - - feature_alpha_dropout - - threshold - - threshold_ - - relu - - relu_ - - glu - - hardtanh - - hardtanh_ - - relu6 - - elu - - elu_ - - selu - - selu_ - - celu - - celu_ - - leaky_relu - - leaky_relu_ - - prelu - - rrelu - - rrelu_ - - logsigmoid - - gelu - - hardshrink - - tanhshrink - - softsign - - softplus - - softmin - - softmax - - gumbel_softmax - - log_softmax - - softshrink - - tanh - - sigmoid - - hardsigmoid - - linear - - bilinear - - silu - - hardswish - - embedding - - embedding_bag - - batch_norm - - instance_norm - - layer_norm - - group_norm - - local_response_norm - - ctc_loss - - nll_loss - - poisson_nll_loss - - gaussian_nll_loss - - kl_div - - cross_entropy - - binary_cross_entropy - - binary_cross_entropy_with_logits - - smooth_l1_loss - - l1_loss - - mse_loss - - margin_ranking_loss - - hinge_embedding_loss - - multilabel_margin_loss - - soft_margin_loss - - multilabel_soft_margin_loss - - cosine_embedding_loss - - multi_margin_loss - - pixel_shuffle - - pixel_unshuffle - - channel_shuffle - - upsample - - interpolate - - upsample_nearest - - upsample_bilinear - - grid_sample - - affine_grid - - pad - - pairwise_distance - - pdist - - cosine_similarity - - one_hot - - triplet_margin_loss - - triplet_margin_with_distance_loss - - normalize - - unfold - - fold - - multi_head_attention_forward - -tensor: - - __add__ - - __and__ - - __bool__ - - __div__ - - __eq__ - - __ge__ - - __gt__ - - __iadd__ - - __iand__ - - __idiv__ - - __ifloordiv__ - - __ilshift__ - - __imod__ - - __imul__ - - __ior__ - - __irshift__ - - __isub__ - - __ixor__ - - __lshift__ - - __matmul__ - - __mod__ - - __mul__ - - __nonzero__ - - __or__ - - __radd__ - - __rmul__ - - __rshift__ - - __sub__ - - __truediv__ - - __xor__ - - abs - - abs_ - - absolute - - absolute_ - - acos - - acos_ - - acosh - - acosh_ - - add - - add_ - - addbmm - - addbmm_ - - addcdiv - - addcdiv_ - - addcmul - - addcmul_ - - addmm - - addmm_ - - addmv - - addmv_ - - addr - - addr_ - - align_as - - align_to - - all - - allclose - - amax - - amin - - angle - - any - - arccos - - arccos_ - - arccosh - - arccosh_ - - arcsin - - arcsin_ - - arcsinh - - arcsinh_ - - arctan - - arctan_ - - arctanh - - arctanh_ - - argmax - - argmin - - argsort - - asin - - asin_ - - asinh - - asinh_ - - atan - - atan2 - - atan2_ - - atan_ - - atanh - - atanh_ - - baddbmm - - baddbmm_ - - bernoulli - - bernoulli_ - - bincount - - bitwise_and - - bitwise_and_ - - bitwise_not - - bitwise_not_ - - bitwise_or - - bitwise_or_ - - bitwise_xor - - bitwise_xor_ - - bmm - - broadcast_to - - cauchy_ - - ceil - - ceil_ - - cholesky - - chunk - - clamp - - cholesky_solve - - cholesky_inverse - - clamp_ - - clamp_max - 
- clamp_max_ - - clip - - clamp_min - - clamp_min_ - - clip_ - - copysign - - copysign_ - - cos - - cos_ - - cosh - - cosh_ - - count_nonzero - - cummax - - cummin - - cumprod - - cumprod_ - - cumsum - - cumsum_ - - deg2rad - - deg2rad_ - - det - - diag - - diag_embed - - diagflat - - diagonal - - diff - - dist - - digamma - - digamma_ - - div - - div_ - - divide - - divide_ - - dot - - eig - - eq - - eq_ - - erf - - equal - - erf_ - - erfc - - erfc_ - - erfinv - - erfinv_ - - exp - - exp2 - - exp2_ - - expm1 - - exp_ - - expm1_ - - exponential_ - - fill_ - - fix - - fill_diagonal_ - - fix_ - - flip - - fliplr - - flatten - - flipud - - float_power - - float_power_ - - floor - - floor_ - - floor_divide - - floor_divide_ - - fmax - - fmin - - fmod - - fmod_ - - frac - - frac_ - - gather - - gcd - - gcd_ - - ge - - ge_ - - geometric_ - - geqrf - - ger - - greater - - greater_ - - gt - - gt_ - - greater_equal - - greater_equal_ - - hardshrink - - heaviside - - heaviside_ - - histc - - hypot - - hypot_ - - igamma - - igamma_ - - igammac - - igammac_ - - index_add - - index_add_ - - inverse - - index_copy - - index_copy_ - - index_fill - - index_fill_ - - index_put - - index_put_ - - inner - - index_select - - isclose - - isfinite - - isinf - - isnan - - isneginf - - isposinf - - isreal - - kron - - kthvalue - - lcm - - lcm_ - - ldexp - - ldexp_ - - le - - le_ - - lerp - - lerp_ - - where - - less - - less_ - - less_equal - - less_equal_ - - lgamma - - lgamma_ - - log - - log10 - - log10_ - - log1p - - log1p_ - - log2 - - log2_ - - log_ - - log_normal_ - - log_softmax - - logcumsumexp - - logdet - - logaddexp - - logaddexp2 - - logical_and - - logical_and_ - - logical_not - - logit - - logical_not_ - - logical_or - - logical_or_ - - logical_xor - - logical_xor_ - - logit_ - - logsumexp - - lstsq - - lt - - lt_ - - lu_solve - - map2_ - - map_ - - masked_fill - - matmul - - masked_fill_ - - masked_scatter - - masked_scatter_ - - masked_select - - matrix_exp - - max - - maximum - - mean - - matrix_power - - median - - min - - minimum - - mm - - mode - - msort - - mul - - mul_ - - multinomial - - multiply - - multiply_ - - mv - - mvlgamma - - mvlgamma_ - - nansum - - narrow - - narrow_copy - - ne - - ne_ - - neg - - neg_ - - negative - - negative_ - - nonzero - - normal_ - - not_equal - - not_equal_ - - permute - - pinverse - - polygamma - - pow - - pow_ - - polygamma_ - - prelu - - prod - - put_ - - rad2deg - - rad2deg_ - - ravel - - real - - reciprocal - - reciprocal_ - - relu - - relu_ - - remainder - - repeat_interleave - - reshape - - remainder_ - - renorm - - renorm_ - - repeat - - reshape_as - - resize_ - - resize_as_ - - roll - - rot90 - - round - - round_ - - rsqrt - - rsqrt_ - - scatter - - scatter_ - - scatter_add - - scatter_add_ - - select - - sgn - - sgn_ - - sigmoid - - sigmoid_ - - sign - - sign_ - - signbit - - sin - - sin_ - - sinc - - sinc_ - - sinh - - sinh_ - - slogdet - - smm - - softmax - - solve - - sort - - split_with_sizes - - sqrt - - sqrt_ - - square - - square_ - - squeeze - - squeeze_ - - sspaddmm - - std - - sub - - sub_ - - sum - - sum_to_size - - svd - - symeig - - t - - t_ - - take - - tan - - tan_ - - tanh - - tanh_ - - tensor_split - - tile - - topk - - transpose - - transpose_ - - triangular_solve - - tril - - tril_ - - triu - - true_divide - - triu_ - - true_divide_ - - trunc - - trunc_ - - type_as - - unbind - - unflatten - - unfold - - unsafe_chunk - - unsqueeze - - unsafe_split - - unsafe_split_with_sizes - - var - - vdot - - unsqueeze_ - - view_as - - xlogy 
- - xlogy_ - -torch: - - _adaptive_avg_pool2d - - _add_relu - - _add_relu_ - - _aminmax - - _batch_norm_impl_index - - _convolution - - abs - - abs_ - - absolute - - acos - - acos_ - - acosh - - acosh_ - - adaptive_avg_pool1d - - adaptive_max_pool1d - - add - - addbmm - - addcdiv - - addcmul - - addmm - - addmv - - addmv_ - - addr - - amax - - affine_grid_generator - - align_tensors - - all - - alpha_dropout - - amin - - alpha_dropout_ - - angle - - any - - arange - - arccos - - arccos_ - - arccosh - - arccosh_ - - arcsin - - arcsin_ - - arcsinh - - arcsinh_ - - arctan - - arctan_ - - arctanh - - arctanh_ - - argmax - - argmin - - argsort - - asin - - asin_ - - asinh - - asinh_ - - atan - - atan2 - - atan_ - - atanh - - atanh_ - - atleast_1d - - atleast_2d - - atleast_3d - - avg_pool1d - - baddbmm - - bartlett_window - - batch_norm_backward_elemt - - batch_norm_backward_reduce - - batch_norm_elemt - - batch_norm_gather_stats - - batch_norm_gather_stats_with_counts - - bernoulli - - batch_norm_stats - - batch_norm_update_stats - - bilinear - - bincount - - binomial - - binary_cross_entropy_with_logits - - bitwise_and - - bitwise_not - - bitwise_or - - bitwise_xor - - blackman_window - - block_diag - - bmm - - broadcast_tensors - - broadcast_to - - cartesian_prod - - cat - - cdist - - ceil - - ceil_ - - celu - - celu_ - - chain_matmul - - channel_shuffle - - cholesky - - cholesky_inverse - - cholesky_solve - - choose_qparams_optimized - - chunk - - clamp - - clamp_ - - clamp_max - - clamp_max_ - - clamp_min - - clamp_min_ - - clip - - clip_ - - clone - - column_stack - - combinations - - constant_pad_nd - - conv1d - - conv2d - - conv3d - - conv_tbc - - conv_transpose1d - - conv_transpose2d - - conv_transpose3d - - cos - - convolution - - copysign - - cos_ - - cosh - - cosh_ - - cosine_embedding_loss - - cosine_similarity - - count_nonzero - - cross - - ctc_loss - - cummax - - cummin - - cumprod - - cumsum - - deg2rad - - deg2rad_ - - det - - diag - - diag_embed - - diff - - diagflat - - diagonal - - digamma - - dist - - div - - divide - - dot - - dropout - - dropout_ - - dsmm - - dstack - - eig - - einsum - - embedding - - embedding_bag - - embedding_renorm_ - - eq - - equal - - erf - - erf_ - - erfc - - erfc_ - - erfinv - - exp - - exp2 - - exp2_ - - exp_ - - expm1 - - expm1_ - - eye - - feature_dropout - - feature_alpha_dropout - - feature_alpha_dropout_ - - feature_dropout_ - - fix - - fill_ - - fix_ - - flatten - - flip - - fliplr - - flipud - - float_power - - floor - - floor_ - - floor_divide - - fmax - - fmin - - fmod - - frac - - frac_ - - full - - frobenius_norm - - full_like - - gather - - gcd - - gcd_ - - ge - - geqrf - - ger - - greater - - greater_equal - - grid_sampler - - grid_sampler_2d - - group_norm - - grid_sampler_3d - - gru - - gru_cell - - gt - - hamming_window - - hann_window - - hardshrink - - heaviside - - hinge_embedding_loss - - histc - - hsmm - - hspmm - - hstack - - hypot - - igamma - - igammac - - index_add - - index_copy - - inner - - index_fill - - index_put - - index_put_ - - index_select - - instance_norm - - isclose - - isfinite - - isinf - - isnan - - isneginf - - isposinf - - istft - - kaiser_window - - kl_div - - kron - - kthvalue - - layer_norm - - lcm - - lcm_ - - ldexp - - ldexp_ - - le - - lerp - - less - - less_equal - - lgamma - - linspace - - log - - log10 - - log10_ - - log1p - - log1p_ - - log2 - - log2_ - - log_softmax - - log_ - - logaddexp - - logaddexp2 - - logcumsumexp - - logdet - - logical_and - - logical_not - - logical_or - - 
logical_xor - - logit - - logit_ - - logspace - - logsumexp - - lstm - - lstm_cell - - lstsq - - lt - - lu_solve - - masked_fill - - margin_ranking_loss - - masked_scatter - - masked_select - - matrix_exp - - matmul - - matrix_power - - matrix_rank - - max - - max_pool1d - - max_pool2d - - max_pool1d_with_indices - - max_pool3d - - maximum - - mean - - median - - min - - minimum - - mm - - mode - - moveaxis - - movedim - - msort - - mul - - multinomial - - multiply - - mv - - mvlgamma - - nan_to_num - - nan_to_num_ - - nanmedian - - nansum - - narrow - - native_batch_norm - - native_group_norm - - narrow_copy - - native_layer_norm - - native_norm - - ne - - neg - - negative - - neg_ - - negative_ - - nextafter - - nonzero - - norm_except_dim - - normal - - not_equal - - nuclear_norm - - pairwise_distance - - pdist - - pinverse - - pixel_shuffle - - pixel_unshuffle - - poisson - - poisson_nll_loss - - polar - - polygamma - - pow - - prelu - - prod - - rad2deg - - promote_types - - rad2deg_ - - range - - ravel - - real - - reciprocal - - relu - - reciprocal_ - - relu_ - - remainder - - renorm - - repeat_interleave - - reshape - - resize_as_ - - roll - - rot90 - - round - - round_ - - rrelu - - rrelu_ - - rsqrt - - row_stack - - rsqrt_ - - rsub - - saddmm - - scalar_tensor - - scatter - - select - - scatter_add - - searchsorted - - selu - - selu_ - - sgn - - sigmoid - - sigmoid_ - - sign - - signbit - - sin - - sin_ - - sinc - - sinc_ - - sinh - - sinh_ - - slogdet - - smm - - softmax - - solve - - sort - - sparse_coo_tensor - - square - - split_with_sizes - - spmm - - sqrt - - sqrt_ - - square_ - - squeeze - - sspaddmm - - stack - - std - - std_mean - - sub - - subtract - - sum - - svd - - swapaxes - - swapdims - - symeig - - t - - take - - tan - - tan_ - - tanh - - tanh_ - - tensordot - - tensor_split - - threshold - - threshold_ - - tile - - topk - - transpose - - trapz - - triangular_solve - - tril - - tril_indices - - triplet_margin_loss - - triu - - triu_indices - - true_divide - - trunc - - trunc_ - - unique_consecutive - - xlogy - - unbind - - unique_dim - - unsafe_chunk - - unsafe_split - - vander - - var - - vdot - - unsafe_split_with_sizes - - unsqueeze - - var_mean - - vstack - - where - - xlogy_ +# Copyright (c) 2023 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# List of ops that register hooks + +functional: + - conv1d + - conv2d + - conv3d + - conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - conv_tbc + - avg_pool1d + - avg_pool2d + - avg_pool3d + - fractional_max_pool2d_with_indices + - fractional_max_pool2d + - fractional_max_pool3d_with_indices + - fractional_max_pool3d + - max_pool1d_with_indices + - max_pool1d + - max_pool2d_with_indices + - max_pool2d + - max_pool3d_with_indices + - max_pool3d + - max_unpool1d + - max_unpool2d + - max_unpool3d + - lp_pool2d + - lp_pool1d + - adaptive_max_pool1d_with_indices + - adaptive_max_pool1d + - adaptive_max_pool2d_with_indices + - adaptive_max_pool2d + - adaptive_max_pool3d_with_indices + - adaptive_max_pool3d + - adaptive_avg_pool1d + - adaptive_avg_pool2d + - adaptive_avg_pool3d + - dropout + - alpha_dropout + - dropout2d + - dropout3d + - feature_alpha_dropout + - threshold + - threshold_ + - relu + - relu_ + - glu + - hardtanh + - hardtanh_ + - relu6 + - elu + - elu_ + - selu + - selu_ + - celu + - celu_ + - leaky_relu + - leaky_relu_ + - prelu + - rrelu + - rrelu_ + - logsigmoid + - gelu + - hardshrink + - tanhshrink + - softsign + - softplus + - softmin + - softmax + - gumbel_softmax + - log_softmax + - softshrink + - tanh + - sigmoid + - hardsigmoid + - linear + - bilinear + - silu + - hardswish + - embedding + - embedding_bag + - batch_norm + - instance_norm + - layer_norm + - group_norm + - local_response_norm + - ctc_loss + - nll_loss + - poisson_nll_loss + - gaussian_nll_loss + - kl_div + - cross_entropy + - binary_cross_entropy + - binary_cross_entropy_with_logits + - smooth_l1_loss + - l1_loss + - mse_loss + - margin_ranking_loss + - hinge_embedding_loss + - multilabel_margin_loss + - soft_margin_loss + - multilabel_soft_margin_loss + - cosine_embedding_loss + - multi_margin_loss + - pixel_shuffle + - pixel_unshuffle + - channel_shuffle + - upsample + - interpolate + - upsample_nearest + - upsample_bilinear + - grid_sample + - affine_grid + - pad + - pairwise_distance + - pdist + - cosine_similarity + - one_hot + - triplet_margin_loss + - triplet_margin_with_distance_loss + - normalize + - unfold + - fold + - multi_head_attention_forward + +tensor: + - __add__ + - __and__ + - __bool__ + - __div__ + - __eq__ + - __ge__ + - __gt__ + - __iadd__ + - __iand__ + - __idiv__ + - __ifloordiv__ + - __ilshift__ + - __imod__ + - __imul__ + - __ior__ + - __irshift__ + - __isub__ + - __ixor__ + - __lshift__ + - __matmul__ + - __mod__ + - __mul__ + - __nonzero__ + - __or__ + - __radd__ + - __rmul__ + - __rshift__ + - __sub__ + - __truediv__ + - __xor__ + - abs + - abs_ + - absolute + - absolute_ + - acos + - acos_ + - acosh + - acosh_ + - add + - add_ + - addbmm + - addbmm_ + - addcdiv + - addcdiv_ + - addcmul + - addcmul_ + - addmm + - addmm_ + - addmv + - addmv_ + - addr + - addr_ + - align_as + - align_to + - all + - allclose + - amax + - amin + - angle + - any + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan2_ + - atan_ + - atanh + - atanh_ + - baddbmm + - baddbmm_ + - bernoulli + - bernoulli_ + - bincount + - bitwise_and + - bitwise_and_ + - bitwise_not + - bitwise_not_ + - bitwise_or + - bitwise_or_ + - bitwise_xor + - bitwise_xor_ + - bmm + - broadcast_to + - cauchy_ + - ceil + - ceil_ + - cholesky + - chunk + - clamp + - cholesky_solve + - cholesky_inverse + - clamp_ + - clamp_max + 
- clamp_max_ + - clip + - clamp_min + - clamp_min_ + - clip_ + - copysign + - copysign_ + - cos + - cos_ + - cosh + - cosh_ + - count_nonzero + - cummax + - cummin + - cumprod + - cumprod_ + - cumsum + - cumsum_ + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diagflat + - diagonal + - diff + - dist + - digamma + - digamma_ + - div + - div_ + - divide + - divide_ + - dot + - eig + - eq + - eq_ + - erf + - equal + - erf_ + - erfc + - erfc_ + - erfinv + - erfinv_ + - exp + - exp2 + - exp2_ + - expm1 + - exp_ + - expm1_ + - exponential_ + - fill_ + - fix + - fill_diagonal_ + - fix_ + - flip + - fliplr + - flatten + - flipud + - float_power + - float_power_ + - floor + - floor_ + - floor_divide + - floor_divide_ + - fmax + - fmin + - fmod + - fmod_ + - frac + - frac_ + - gather + - gcd + - gcd_ + - ge + - ge_ + - geometric_ + - geqrf + - ger + - greater + - greater_ + - gt + - gt_ + - greater_equal + - greater_equal_ + - hardshrink + - heaviside + - heaviside_ + - histc + - hypot + - hypot_ + - igamma + - igamma_ + - igammac + - igammac_ + - index_add + - index_add_ + - inverse + - index_copy + - index_copy_ + - index_fill + - index_fill_ + - index_put + - index_put_ + - inner + - index_select + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - isreal + - kron + - kthvalue + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - le_ + - lerp + - lerp_ + - where + - less + - less_ + - less_equal + - less_equal_ + - lgamma + - lgamma_ + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_ + - log_normal_ + - log_softmax + - logcumsumexp + - logdet + - logaddexp + - logaddexp2 + - logical_and + - logical_and_ + - logical_not + - logit + - logical_not_ + - logical_or + - logical_or_ + - logical_xor + - logical_xor_ + - logit_ + - logsumexp + - lstsq + - lt + - lt_ + - lu_solve + - map2_ + - map_ + - masked_fill + - matmul + - masked_fill_ + - masked_scatter + - masked_scatter_ + - masked_select + - matrix_exp + - max + - maximum + - mean + - matrix_power + - median + - min + - minimum + - mm + - mode + - msort + - mul + - mul_ + - multinomial + - multiply + - multiply_ + - mv + - mvlgamma + - mvlgamma_ + - nansum + - narrow + - narrow_copy + - ne + - ne_ + - neg + - neg_ + - negative + - negative_ + - nonzero + - normal_ + - not_equal + - not_equal_ + - permute + - pinverse + - polygamma + - pow + - pow_ + - polygamma_ + - prelu + - prod + - put_ + - rad2deg + - rad2deg_ + - ravel + - real + - reciprocal + - reciprocal_ + - relu + - relu_ + - remainder + - repeat_interleave + - reshape + - remainder_ + - renorm + - renorm_ + - repeat + - reshape_as + - resize_ + - resize_as_ + - roll + - rot90 + - round + - round_ + - rsqrt + - rsqrt_ + - scatter + - scatter_ + - scatter_add + - scatter_add_ + - select + - sgn + - sgn_ + - sigmoid + - sigmoid_ + - sign + - sign_ + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - split_with_sizes + - sqrt + - sqrt_ + - square + - square_ + - squeeze + - squeeze_ + - sspaddmm + - std + - sub + - sub_ + - sum + - sum_to_size + - svd + - symeig + - t + - t_ + - take + - tan + - tan_ + - tanh + - tanh_ + - tensor_split + - tile + - topk + - transpose + - transpose_ + - triangular_solve + - tril + - tril_ + - triu + - true_divide + - triu_ + - true_divide_ + - trunc + - trunc_ + - type_as + - unbind + - unflatten + - unfold + - unsafe_chunk + - unsqueeze + - unsafe_split + - unsafe_split_with_sizes + - var + - vdot + - unsqueeze_ + - view_as + - xlogy 
+ - xlogy_ + +torch: + - _adaptive_avg_pool2d + - _add_relu + - _add_relu_ + - _aminmax + - _batch_norm_impl_index + - _convolution + - abs + - abs_ + - absolute + - acos + - acos_ + - acosh + - acosh_ + - adaptive_avg_pool1d + - adaptive_max_pool1d + - add + - addbmm + - addcdiv + - addcmul + - addmm + - addmv + - addmv_ + - addr + - amax + - affine_grid_generator + - align_tensors + - all + - alpha_dropout + - amin + - alpha_dropout_ + - angle + - any + - arange + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan_ + - atanh + - atanh_ + - atleast_1d + - atleast_2d + - atleast_3d + - avg_pool1d + - baddbmm + - bartlett_window + - batch_norm_backward_elemt + - batch_norm_backward_reduce + - batch_norm_elemt + - batch_norm_gather_stats + - batch_norm_gather_stats_with_counts + - bernoulli + - batch_norm_stats + - batch_norm_update_stats + - bilinear + - bincount + - binomial + - binary_cross_entropy_with_logits + - bitwise_and + - bitwise_not + - bitwise_or + - bitwise_xor + - blackman_window + - block_diag + - bmm + - broadcast_tensors + - broadcast_to + - cartesian_prod + - cat + - cdist + - ceil + - ceil_ + - celu + - celu_ + - chain_matmul + - channel_shuffle + - cholesky + - cholesky_inverse + - cholesky_solve + - choose_qparams_optimized + - chunk + - clamp + - clamp_ + - clamp_max + - clamp_max_ + - clamp_min + - clamp_min_ + - clip + - clip_ + - clone + - column_stack + - combinations + - constant_pad_nd + - conv1d + - conv2d + - conv3d + - conv_tbc + - conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - cos + - convolution + - copysign + - cos_ + - cosh + - cosh_ + - cosine_embedding_loss + - cosine_similarity + - count_nonzero + - cross + - ctc_loss + - cummax + - cummin + - cumprod + - cumsum + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diff + - diagflat + - diagonal + - digamma + - dist + - div + - divide + - dot + - dropout + - dropout_ + - dsmm + - dstack + - eig + - einsum + - embedding + - embedding_bag + - embedding_renorm_ + - eq + - equal + - erf + - erf_ + - erfc + - erfc_ + - erfinv + - exp + - exp2 + - exp2_ + - exp_ + - expm1 + - expm1_ + - eye + - feature_dropout + - feature_alpha_dropout + - feature_alpha_dropout_ + - feature_dropout_ + - fix + - fill_ + - fix_ + - flatten + - flip + - fliplr + - flipud + - float_power + - floor + - floor_ + - floor_divide + - fmax + - fmin + - fmod + - frac + - frac_ + - full + - frobenius_norm + - full_like + - gather + - gcd + - gcd_ + - ge + - geqrf + - ger + - greater + - greater_equal + - grid_sampler + - grid_sampler_2d + - group_norm + - grid_sampler_3d + - gru + - gru_cell + - gt + - hamming_window + - hann_window + - hardshrink + - heaviside + - hinge_embedding_loss + - histc + - hsmm + - hspmm + - hstack + - hypot + - igamma + - igammac + - index_add + - index_copy + - inner + - index_fill + - index_put + - index_put_ + - index_select + - instance_norm + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - istft + - kaiser_window + - kl_div + - kron + - kthvalue + - layer_norm + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - lerp + - less + - less_equal + - lgamma + - linspace + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_softmax + - log_ + - logaddexp + - logaddexp2 + - logcumsumexp + - logdet + - logical_and + - logical_not + - logical_or + - 
logical_xor + - logit + - logit_ + - logspace + - logsumexp + - lstm + - lstm_cell + - lstsq + - lt + - lu_solve + - masked_fill + - margin_ranking_loss + - masked_scatter + - masked_select + - matrix_exp + - matmul + - matrix_power + - matrix_rank + - max + - max_pool1d + - max_pool2d + - max_pool1d_with_indices + - max_pool3d + - maximum + - mean + - median + - min + - minimum + - mm + - mode + - moveaxis + - movedim + - msort + - mul + - multinomial + - multiply + - mv + - mvlgamma + - nan_to_num + - nan_to_num_ + - nanmedian + - nansum + - narrow + - native_batch_norm + - native_group_norm + - narrow_copy + - native_layer_norm + - native_norm + - ne + - neg + - negative + - neg_ + - negative_ + - nextafter + - nonzero + - norm_except_dim + - normal + - not_equal + - nuclear_norm + - pairwise_distance + - pdist + - pinverse + - pixel_shuffle + - pixel_unshuffle + - poisson + - poisson_nll_loss + - polar + - polygamma + - pow + - prelu + - prod + - rad2deg + - promote_types + - rad2deg_ + - range + - ravel + - real + - reciprocal + - relu + - reciprocal_ + - relu_ + - remainder + - renorm + - repeat_interleave + - reshape + - resize_as_ + - roll + - rot90 + - round + - round_ + - rrelu + - rrelu_ + - rsqrt + - row_stack + - rsqrt_ + - rsub + - saddmm + - scalar_tensor + - scatter + - select + - scatter_add + - searchsorted + - selu + - selu_ + - sgn + - sigmoid + - sigmoid_ + - sign + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - sparse_coo_tensor + - square + - split_with_sizes + - spmm + - sqrt + - sqrt_ + - square_ + - squeeze + - sspaddmm + - stack + - std + - std_mean + - sub + - subtract + - sum + - svd + - swapaxes + - swapdims + - symeig + - t + - take + - tan + - tan_ + - tanh + - tanh_ + - tensordot + - tensor_split + - threshold + - threshold_ + - tile + - topk + - transpose + - trapz + - triangular_solve + - tril + - tril_indices + - triplet_margin_loss + - triu + - triu_indices + - true_divide + - trunc + - trunc_ + - unique_consecutive + - xlogy + - unbind + - unique_dim + - unsafe_chunk + - unsafe_split + - vander + - var + - vdot + - unsafe_split_with_sizes + - unsqueeze + - var_mean + - vstack + - where + - xlogy_ + +torch_npu: + - one_ + - npu_sort_v2 + - npu_transpose + - npu_broadcast + - npu_dtype_cast + - empty_with_format + - npu_one_hot + - npu_stride_add + - npu_ps_roi_pooling + - npu_roi_align + - npu_nms_v4 + - npu_iou + - npu_nms_with_mask + - npu_pad + - npu_bounding_box_encode + - npu_bounding_box_decode + - npu_batch_nms + - npu_slice + - _npu_dropout + - npu_indexing + - npu_ifmr + - npu_max + - npu_scatter + - npu_layer_norm_eval + - npu_alloc_float_status + - npu_confusion_transpose + - npu_bmmV2 + - fast_gelu + - npu_sub_sample + - npu_deformable_conv2d + - npu_mish + - npu_anchor_response_flags + - npu_yolo_boxes_encode + - npu_grid_assign_positive + - npu_normalize_batch + - npu_masked_fill_range + - npu_linear + - npu_bert_apply_adam + - npu_giou + - npu_ciou + - npu_diou + - npu_sign_bits_pack + - npu_sign_bits_unpack + - npu_flash_attention + - npu_scaled_masked_softmax + - npu_rotary_mul + - npu_roi_align + - npu_roi_alignbk + - npu_ptiou + - npu_fusion_attention + +distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - _reduce_scatter_base + - _all_gather_base diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py 
b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py
index 7d16ac993ed45faa0f9b48bb64050592e15ef4d2..85bb64df1ca7c950ca8d5ee0668ce7841fd0aa31 100644
--- a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py
+++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py
@@ -1,29 +1,31 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-
-import os
-import yaml
-
-from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    Ops = yaml.safe_load(f)
-    WrapFunctionalOps = Ops.get('functional')
-    WrapTensorOps = Ops.get('tensor')
-    WrapTorchOps = Ops.get('torch')
\ No newline at end of file
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import yaml
+
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
+
+cur_path = os.path.dirname(os.path.realpath(__file__))
+yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+with FileOpen(yaml_path, 'r') as f:
+    Ops = yaml.safe_load(f)
+    WrapFunctionalOps = Ops.get('functional')
+    WrapTensorOps = Ops.get('tensor')
+    WrapTorchOps = Ops.get('torch')
+    WrapDistributedOps = Ops.get('distributed')
+    WrapNpuOps = Ops.get('torch_npu')
diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_distributed.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01c5305458dbef2b3f949ed0196a2c20c5b426b
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_distributed.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+
+import torch.distributed as dist
+import yaml
+
+from api_accuracy_checker.hook_module.hook_module import HOOKModule
+from api_accuracy_checker.common.utils import torch_device_guard, Const
+from api_accuracy_checker.common.config import msCheckerConfig
+from api_accuracy_checker.hook_module.utils import WrapDistributedOps
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
+
+
+distributed_func = {}
+for f in dir(dist):
+    distributed_func[f] = getattr(dist, f)
+
+
+def get_distributed_ops():
+    global WrapDistributedOps
+    _all_distributed_ops = dir(dist)
+    if msCheckerConfig.white_list:
+        return set(WrapDistributedOps) & set(_all_distributed_ops) & set(msCheckerConfig.white_list)
+    else:
+        return set(WrapDistributedOps) & set(_all_distributed_ops)
+
+
+class HOOKDistributedOP(object):
+    pass
+
+
+class DistributedOPTemplate(HOOKModule):
+    def __init__(self, op_name, hook):
+        self.op_name_ = op_name
+        self.prefix_op_name_ = "Distributed_" + str(op_name) + "_"
+        super().__init__(hook)
+        if self.op_name_ in Const.INPLACE_LIST:
+            self.register_forward_pre_hook(hook(self.prefix + Const.PRE_FORWARD))
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        return distributed_func.get(self.op_name_)(*args, **kwargs)
+
+
+def wrap_distributed_op(op_name, hook):
+    def distributed_op_template(*args, **kwargs):
+        return DistributedOPTemplate(op_name, hook)(*args, **kwargs)
+
+    return distributed_op_template
+
+
+def wrap_distributed_ops_and_bind(hook):
+    _distributed_ops = get_distributed_ops()
+    for op_name in _distributed_ops:
+        setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook))
diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..688b5db72abff091518efc9c788c665e737f13ab
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import os +import torch +import torch_npu +import yaml + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard, torch_without_guard_version +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.hook_module.utils import WrapNpuOps +from ..common.file_check_util import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapNpuOps = yaml.safe_load(f).get('torch_npu') + + +def get_npu_ops(): + global WrapNpuOps + if torch_without_guard_version: + _npu_ops = dir(torch.ops.npu) + else: + _npu_ops = dir(torch_npu._C._VariableFunctionsClass) + + if msCheckerConfig.white_list: + return set(WrapNpuOps) & set(_npu_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapNpuOps) & set(_npu_ops) + + +class HOOKNpuOP(object): + pass + + +class NpuOPTemplate(HOOKModule): + + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "NPU_" + str(op_name) + "_" + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + if torch_without_guard_version: + return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) + else: + return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) + + +def wrap_npu_op(op_name, hook): + + def npu_op_template(*args, **kwargs): + return NpuOPTemplate(op_name, hook)(*args, **kwargs) + + return npu_op_template + + +def wrap_npu_ops_and_bind(hook): + _npu_ops = get_npu_ops() + for op_name in _npu_ops: + setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0c0f3305c7104e87f64d6002996ed63c342c2eb9..1de0944eaa5e838cd908d05ae963e32709b8f200 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -1,291 +1,293 @@ -import argparse -import os -import copy -import sys -import time - -try: - import torch_npu -except ImportError: - is_gpu = True - current_device = "cuda" -else: - is_gpu = False - current_device = "npu" - -import yaml -import torch -from tqdm import tqdm -from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log, check_file_or_directory_path, initialize_save_path, Const -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate -from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate -from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.run_ut.ut_api_info import UtAPIInfo -from api_accuracy_checker.common.config import msCheckerConfig - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, \ - change_mode, check_file_suffix, check_link - -ut_error_data_dir = 'ut_error_data' - - -def exec_api(api_type, api_name, args, kwargs): - if api_type == "Functional": - functional_api = FunctionalOPTemplate(api_name, str, False) - out = functional_api.forward(*args, **kwargs) - if api_type == "Tensor": - tensor_api = TensorOPTemplate(api_name, str, 
False) - out = tensor_api.forward(*args, **kwargs) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - out = torch_api.forward(*args, **kwargs) - return out - - -def deal_detach(arg, to_detach=True): - return arg.detach() if to_detach else arg - - -def generate_device_params(input_args, input_kwargs, need_backward): - def recursive_arg_to_device(arg_in, to_detach=True): - if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in) - elif isinstance(arg_in, torch.Tensor): - if need_backward and arg_in.requires_grad: - arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_() - temp_arg_in = arg_in * 1 - arg_in = temp_arg_in.type_as(arg_in) - arg_in.retain_grad() - return arg_in - else: - return deal_detach(arg_in.clone(), to_detach).to(current_device) - else: - return arg_in - - device_args = recursive_arg_to_device(input_args) - device_kwargs = {key: recursive_arg_to_device(value, key != "out") for key, value in input_kwargs.items()} - return device_args, device_kwargs - - -def generate_cpu_params(input_args, input_kwargs, need_backward): - first_dtype = None - - def recursive_arg_to_cpu(arg_in, to_detach=True): - nonlocal first_dtype - if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_cpu(arg, to_detach) for arg in arg_in) - elif isinstance(arg_in, torch.Tensor): - if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach).requires_grad_() - if first_dtype is None: - first_dtype = arg_in.dtype - else: - arg_in = deal_detach(arg_in.clone(), to_detach).requires_grad_() - temp_arg_in = arg_in * 1 - arg_in = temp_arg_in.type_as(arg_in) - arg_in.retain_grad() - return arg_in - else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach) - if first_dtype is None: - first_dtype = arg_in.dtype - return arg_in - return deal_detach(arg_in.clone(), to_detach) - else: - return arg_in - - cpu_args = recursive_arg_to_cpu(input_args) - cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out") for key, value in input_kwargs.items()} - return cpu_args, cpu_kwargs - - -def run_ut(forward_file, backward_file, out_path, save_error_data): - print_info_log("start UT test") - forward_content = get_json_contents(forward_file) - backward_content = get_json_contents(backward_file) - api_setting_dict = get_json_contents("torch_ut_setting.json") - compare = Comparator(out_path) - for api_full_name, api_info_dict in tqdm(forward_content.items()): - try: - if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split("*") - if api_name not in set(msCheckerConfig.white_list): - continue - data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, - data_info.bench_out, - data_info.device_out, - data_info.bench_grad_out, - data_info.device_grad_out) - if save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) - except Exception as err: - [_, api_name, _] = api_full_name.split("*") - if "expected scalar type Long" in str(err): - print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to 
CONVERT_API " - f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") - else: - print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) - compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err))) - change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY) - compare.print_pretest_result() - - -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): - if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") - for element in data_info.in_fwd_data_list: - UtAPIInfo(api_full_name + '.forward.input', element, ut_error_data_dir) - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out, ut_error_data_dir) - UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_out, ut_error_data_dir) - UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in, ut_error_data_dir) - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out, ut_error_data_dir) - UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad_out, ut_error_data_dir) - - -def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): - in_fwd_data_list = [] - [api_type, api_name, _] = api_full_name.split("*") - args, kwargs, need_grad = get_api_info(api_info_dict, api_name) - in_fwd_data_list.append(args) - in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content - need_backward = need_backward and need_grad - if not need_grad: - print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) - if kwargs.get("device"): - del kwargs["device"] - cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) - device_args, device_kwargs = generate_device_params(args, kwargs, need_backward) - grad_out, device_grad_out = None, None - out = exec_api(api_type, api_name, cpu_args, cpu_kwargs) - device_out = exec_api(api_type, api_name, device_args, device_kwargs) - grad_input_index = api_setting_dict.get(api_name) - grad_index = None - grad = None - if grad_input_index is not None: - grad_index = grad_input_index.get('grad_index') - - if need_backward: - grad_out, device_grad_out, grad, device_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index, device_args, - device_out, out) - if grad_index is not None: - return UtDataInfo(grad_out, device_grad_out, device_out[grad_index], out[grad_index], grad, in_fwd_data_list) - return UtDataInfo(grad_out, device_grad_out, device_out, out, grad, in_fwd_data_list) - - -def get_api_info(api_info_dict, api_name): - convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) - need_grad = True - if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): - need_grad = False - args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) - return args, kwargs, need_grad - - -def run_backward(api_full_name, args, backward_content, grad_index, device_args, device_out, out): - backward_args = backward_content[api_full_name] - grad = gen_args(backward_args)[0] - cpu_grad, _ = generate_cpu_params(grad, {}, False) - if grad_index is not None: - out[grad_index].backward(cpu_grad) - elif isinstance(out, (list, tuple)): - raise NotImplementedError("Multiple backward is not supported.") - else: - out.backward(cpu_grad) - args_grad = [] - for arg in args: - if isinstance(arg, 
torch.Tensor): - args_grad.append(arg.grad) - grad_out = args_grad - device_grad = grad.clone().detach().to(current_device) - if grad_index is not None: - device_out[grad_index].backward(device_grad) - else: - device_out.backward(device_grad) - device_args_grad = [] - for arg in device_args: - if isinstance(arg, torch.Tensor): - device_args_grad.append(arg.grad) - device_grad_out = device_args_grad - return grad_out, device_grad_out, grad, device_grad - - -def initialize_save_error_data(): - error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, - ability=FileCheckConst.WRITE_ABLE) - error_data_path = error_data_path_checker.common_check() - global ut_error_data_dir - ut_error_data_dir = 'ut_error_data' + time.strftime("%Y%m%d%H%M%S") - initialize_save_path(error_data_path, ut_error_data_dir) - - -def _run_ut_parser(parser): - parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", type=str, - help=" The api param tool forward result file: generate from api param tool, " - "a json file.", - required=True) - parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str, - help=" The api param tool backward result file: generate from api param tool, " - "a json file.", - required=True) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The ut task result out path.", - required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", - help=" Save compare failed api output.", required=False) - parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", - help=" whether to turn on jit compile", required=False) - parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set device id to run ut", - default=0, required=False) - - -def _run_ut(): - parser = argparse.ArgumentParser() - _run_ut_parser(parser) - args = parser.parse_args(sys.argv[1:]) - if not is_gpu: - torch.npu.set_compile_mode(jit_compile=args.jit_compile) - used_device = current_device + ":" + str(args.device_id) - try: - if is_gpu: - torch.cuda.set_device(used_device) - else: - torch.npu.set_device(used_device) - except Exception as error: - print_error_log(f"Set device id failed. 
-        raise NotImplementedError from error
-    check_link(args.forward_input_file)
-    check_link(args.backward_input_file)
-    forward_file = os.path.realpath(args.forward_input_file)
-    backward_file = os.path.realpath(args.backward_input_file)
-    check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX)
-    check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX)
-    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
-    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
-    out_path = out_path_checker.common_check()
-    save_error_data = args.save_error_data
-    if save_error_data:
-        initialize_save_error_data()
-    run_ut(forward_file, backward_file, out_path, save_error_data)
-
-
-class UtDataInfo:
-    def __init__(self, bench_grad_out, device_grad_out, device_out, bench_out, grad_in, in_fwd_data_list):
-        self.bench_grad_out = bench_grad_out
-        self.device_grad_out = device_grad_out
-        self.device_out = device_out
-        self.bench_out = bench_out
-        self.grad_in = grad_in
-        self.in_fwd_data_list = in_fwd_data_list
-
-
-if __name__ == '__main__':
-    _run_ut()
-    print_info_log("UT task completed.")
+import argparse
+import os
+import copy
+import sys
+import time
+
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+    current_device = "cuda"
+else:
+    is_gpu = False
+    current_device = "npu"
+
+import yaml
+import torch
+from tqdm import tqdm
+from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
+from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \
+    print_error_log, check_file_or_directory_path, initialize_save_path, Const
+from api_accuracy_checker.compare.compare import Comparator
+from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate
+from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate
+from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate
+from api_accuracy_checker.run_ut.ut_api_info import UtAPIInfo
+from api_accuracy_checker.common.config import msCheckerConfig
+
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, \
+    change_mode, check_file_suffix, check_link
+
+ut_error_data_dir = 'ut_error_data'
+
+
+def exec_api(api_type, api_name, args, kwargs):
+    if api_type == "Functional":
+        functional_api = FunctionalOPTemplate(api_name, str, False)
+        out = functional_api.forward(*args, **kwargs)
+    elif api_type == "Tensor":
+        tensor_api = TensorOPTemplate(api_name, str, False)
+        out = tensor_api.forward(*args, **kwargs)
+    elif api_type == "Torch":
+        torch_api = TorchOPTemplate(api_name, str, False)
+        out = torch_api.forward(*args, **kwargs)
+    else:
+        # An unknown api_type would otherwise fall through with `out` unbound.
+        raise ValueError(f"Unsupported api type: {api_type}")
+    return out
+
+
+def deal_detach(arg, to_detach=True):
+    return arg.detach() if to_detach else arg
+
+
+def generate_device_params(input_args, input_kwargs, need_backward):
+    def recursive_arg_to_device(arg_in, to_detach=True):
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
+                # `* 1` creates a non-leaf copy so that retain_grad() can record
+                # the gradient flowing into it during backward.
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return deal_detach(arg_in.clone(), to_detach).to(current_device)
+        else:
+            return arg_in
+
+    device_args = recursive_arg_to_device(input_args)
+    device_kwargs = {key: recursive_arg_to_device(value, key != "out") for key, value in input_kwargs.items()}
+    return device_args, device_kwargs
+
+
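Note on the clone-and-move pattern in generate_device_params: a minimal standalone sketch of the same idea outside the checker (the helper name to_device_keep_grad is illustrative only, and the device is hedged on CUDA availability):

    import torch

    def to_device_keep_grad(t, device):
        # Detach from any prior graph, move to the target device, and
        # re-enable gradient tracking on the copy.
        moved = t.detach().clone().to(device).requires_grad_()
        # `* 1` yields a non-leaf alias; retain_grad() makes its .grad
        # observable after backward, mirroring temp_arg_in above.
        alias = (moved * 1).type_as(moved)
        alias.retain_grad()
        return alias

    device = "cuda" if torch.cuda.is_available() else "cpu"
    y = to_device_keep_grad(torch.randn(3, requires_grad=True), device)
    y.sum().backward()
    print(y.grad)  # all ones, recorded on the non-leaf alias
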
+def generate_cpu_params(input_args, input_kwargs, need_backward):
+    first_dtype = None
+
+    def recursive_arg_to_cpu(arg_in, to_detach=True):
+        nonlocal first_dtype
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_cpu(arg, to_detach) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                # Raise low-precision dtypes to the CPU benchmark precision,
+                # keeping all tensors of one call at the same dtype.
+                if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype:
+                    arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach).requires_grad_()
+                    if first_dtype is None:
+                        first_dtype = arg_in.dtype
+                else:
+                    arg_in = deal_detach(arg_in.clone(), to_detach).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype:
+                    arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach)
+                    if first_dtype is None:
+                        first_dtype = arg_in.dtype
+                    return arg_in
+                return deal_detach(arg_in.clone(), to_detach)
+        else:
+            return arg_in
+
+    cpu_args = recursive_arg_to_cpu(input_args)
+    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out") for key, value in input_kwargs.items()}
+    return cpu_args, cpu_kwargs
+
+
+def run_ut(forward_file, backward_file, out_path, save_error_data):
+    print_info_log("start UT test")
+    forward_content = get_json_contents(forward_file)
+    backward_content = get_json_contents(backward_file)
+    api_setting_dict = get_json_contents("torch_ut_setting.json")
+    compare = Comparator(out_path)
+    for api_full_name, api_info_dict in tqdm(forward_content.items()):
+        try:
+            if msCheckerConfig.white_list:
+                [_, api_name, _] = api_full_name.split("*")
+                if api_name not in set(msCheckerConfig.white_list):
+                    continue
+            if "Distributed" in api_full_name or "NPU" in api_full_name:
+                continue
+            data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict)
+            is_fwd_success, is_bwd_success = compare.compare_output(api_full_name,
+                                                                    data_info.bench_out,
+                                                                    data_info.device_out,
+                                                                    data_info.bench_grad_out,
+                                                                    data_info.device_grad_out)
+            if save_error_data:
+                do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
+        except Exception as err:
+            [_, api_name, _] = api_full_name.split("*")
+            if "expected scalar type Long" in str(err):
+                print_warn_log(f"API {api_name} does not support int32 tensors on CPU; please add {api_name} to the CONVERT_API "
+                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py.")
+            else:
+                print_error_log(f"Run {api_full_name} UT Error: {str(err)}")
+            compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err)))
+    change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY)
+    change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY)
+    compare.print_pretest_result()
+
+
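For orientation, run_ut iterates over entries keyed by the '*'-delimited full API name written at dump time (api type, api name, call index), which is what the split("*") above unpacks. A minimal sketch of the white-list filtering with made-up entries (the key Functional*relu*0 and the file contents are illustrative, not taken from a real dump):

    # Illustrative shapes only; real content comes from the dump stage.
    forward_content = {"Functional*relu*0": {"args": [], "kwargs": {}}}
    white_list = ["relu"]

    for api_full_name in forward_content:
        _, api_name, _ = api_full_name.split("*")
        if white_list and api_name not in set(white_list):
            continue  # skip APIs that are not under test
        print(f"would run UT for {api_full_name}")
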
+def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
+    if not is_fwd_success or not is_bwd_success:
+        api_full_name = api_full_name.replace("*", ".")
+        for element in data_info.in_fwd_data_list:
+            UtAPIInfo(api_full_name + '.forward.input', element, ut_error_data_dir)
+        UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out, ut_error_data_dir)
+        UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_out, ut_error_data_dir)
+        UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in, ut_error_data_dir)
+        UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out, ut_error_data_dir)
+        UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad_out, ut_error_data_dir)
+
+
+def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict):
+    in_fwd_data_list = []
+    [api_type, api_name, _] = api_full_name.split("*")
+    args, kwargs, need_grad = get_api_info(api_info_dict, api_name)
+    in_fwd_data_list.append(args)
+    in_fwd_data_list.append(kwargs)
+    need_backward = api_full_name in backward_content and need_grad
+    if not need_grad:
+        print_warn_log("%s called with an out= argument does not support automatic differentiation; skipping backward." % api_full_name)
+    if kwargs.get("device"):
+        del kwargs["device"]
+    cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward)
+    device_args, device_kwargs = generate_device_params(args, kwargs, need_backward)
+    grad_out, device_grad_out = None, None
+    out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
+    device_out = exec_api(api_type, api_name, device_args, device_kwargs)
+    grad_input_index = api_setting_dict.get(api_name)
+    grad_index = None
+    grad = None
+    if grad_input_index is not None:
+        grad_index = grad_input_index.get('grad_index')
+
+    if need_backward:
+        grad_out, device_grad_out, grad, device_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index,
+                                                                    device_args, device_out, out)
+    if grad_index is not None:
+        return UtDataInfo(grad_out, device_grad_out, device_out[grad_index], out[grad_index], grad, in_fwd_data_list)
+    return UtDataInfo(grad_out, device_grad_out, device_out, out, grad, in_fwd_data_list)
+
+
+def get_api_info(api_info_dict, api_name):
+    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+    need_grad = True
+    if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"):
+        need_grad = False
+    args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type)
+    return args, kwargs, need_grad
+
+
+def run_backward(api_full_name, args, backward_content, grad_index, device_args, device_out, out):
+    backward_args = backward_content[api_full_name]
+    grad = gen_args(backward_args)[0]
+    cpu_grad, _ = generate_cpu_params(grad, {}, False)
+    if grad_index is not None:
+        out[grad_index].backward(cpu_grad)
+    elif isinstance(out, (list, tuple)):
+        raise NotImplementedError("Multiple backward is not supported.")
+    else:
+        out.backward(cpu_grad)
+    args_grad = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            args_grad.append(arg.grad)
+    grad_out = args_grad
+    device_grad = grad.clone().detach().to(current_device)
+    if grad_index is not None:
+        device_out[grad_index].backward(device_grad)
+    else:
+        device_out.backward(device_grad)
+    device_args_grad = []
+    for arg in device_args:
+        if isinstance(arg, torch.Tensor):
+            device_args_grad.append(arg.grad)
+    device_grad_out = device_args_grad
+    return grad_out, device_grad_out, grad, device_grad
+
+
+def initialize_save_error_data():
+    error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR,
+                                          ability=FileCheckConst.WRITE_ABLE)
+    error_data_path = error_data_path_checker.common_check()
+    global ut_error_data_dir
+    ut_error_data_dir = 'ut_error_data' + time.strftime("%Y%m%d%H%M%S")
+    initialize_save_path(error_data_path, ut_error_data_dir)
+
+
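run_backward seeds the CPU and device graphs with the same upstream gradient and then collects the per-input .grad fields. The underlying PyTorch mechanics, as a self-contained sketch:

    import torch

    a = torch.randn(4, requires_grad=True)
    out = a * 3
    upstream = torch.ones_like(out)  # stands in for the dumped backward input
    out.backward(upstream)           # non-scalar outputs need an explicit gradient
    print(a.grad)                    # tensor of 3s, gathered like args_grad above
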
+def _run_ut_parser(parser):
+    parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", type=str,
+                        help="Forward API info file generated by the API param dump tool (a JSON file).",
+                        required=True)
+    parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str,
+                        help="Backward API info file generated by the API param dump tool (a JSON file).",
+                        required=True)
+    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
+                        help="Output path for the UT task results.",
+                        required=False)
+    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+                        help="Save the outputs of APIs that fail the comparison.", required=False)
+    parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
+                        help="Whether to enable jit compile.", required=False)
+    parser.add_argument("-d", "--device", dest="device_id", type=int, help="Device id on which to run the UT task.",
+                        default=0, required=False)
+
+
+def _run_ut():
+    parser = argparse.ArgumentParser()
+    _run_ut_parser(parser)
+    args = parser.parse_args(sys.argv[1:])
+    if not is_gpu:
+        torch.npu.set_compile_mode(jit_compile=args.jit_compile)
+    used_device = current_device + ":" + str(args.device_id)
+    try:
+        if is_gpu:
+            torch.cuda.set_device(used_device)
+        else:
+            torch.npu.set_device(used_device)
+    except Exception as error:
+        print_error_log(f"Set device id failed. Device id is: {args.device_id}")
+        raise NotImplementedError from error
+    check_link(args.forward_input_file)
+    check_link(args.backward_input_file)
+    forward_file = os.path.realpath(args.forward_input_file)
+    backward_file = os.path.realpath(args.backward_input_file)
+    check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX)
+    check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX)
+    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
+    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
+    out_path = out_path_checker.common_check()
+    save_error_data = args.save_error_data
+    if save_error_data:
+        initialize_save_error_data()
+    run_ut(forward_file, backward_file, out_path, save_error_data)
+
+
+class UtDataInfo:
+    def __init__(self, bench_grad_out, device_grad_out, device_out, bench_out, grad_in, in_fwd_data_list):
+        self.bench_grad_out = bench_grad_out
+        self.device_grad_out = device_grad_out
+        self.device_out = device_out
+        self.bench_out = bench_out
+        self.grad_in = grad_in
+        self.in_fwd_data_list = in_fwd_data_list
+
+
+if __name__ == '__main__':
+    _run_ut()
+    print_info_log("UT task completed.")
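A typical invocation of the new entry point (file names are illustrative; -forward and -backward are required, the remaining flags are optional):

    python run_ut.py -forward forward_info.json -backward backward_info.json \
        -o ./ut_results -save_error_data -d 0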