diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 30994f709444c4d479f5c807289be6e6bb58e25b..bca971116187d54e6a7a85eda75107b25b2ebb3e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -27,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate +from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate +from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.core.common.file_check import FileOpen, FileChecker, \ @@ -78,6 +80,12 @@ def exec_api(api_type, api_name, args, kwargs): if api_type == "Torch": torch_api = TorchOPTemplate(api_name, str, False) out = torch_api.forward(*args, **kwargs) + if api_type == "Aten": + torch_api = AtenOPTemplate(api_name, None, False) + out = torch_api.forward(*args, **kwargs) + if api_type == "NPU": + torch_api = NpuOPTemplate(api_name, None, False) + out = torch_api.forward(*args, **kwargs) return out diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb06867371c6234583cabd485bcaa3dd671cb00c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py @@ -0,0 +1,15 @@ +import os +from pkgutil import iter_modules +from importlib import import_module + +""" +gpu and cpu not implement benchmark function, supplementary benchmarking function implementation +""" + +package_path = os.path.dirname(os.path.realpath(__file__)) +for _, module_name, _ in iter_modules([package_path]): + module = import_module(f"{__name__}.{module_name}") + for attr_name in dir(module): + attr = getattr(module, attr_name) + if callable(attr) and "npu_custom" not in attr_name: + globals()[attr_name] = attr diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0954911c19f35fd3e6aae1aa7a7cfea467ae5e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py @@ -0,0 +1,31 @@ +import torch + +from msprobe.pytorch.function_factory import npu_custom_functions + + +@npu_custom_functions +def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay, + beta1, beta2, eps, grad, max_grad_norm, amsgrad, maximize, out): + var, m, v = out + if amsgrad: + max_grad_norm = (torch.rand(var.shape) * 10.0 - 5.0).to(var.dtype) + beta1_power_out = beta1_power * beta1 + beta2_power_out = beta2_power * beta2 + var_t = var * (1 + (-lr * weight_decay)) + gt = -grad if maximize else grad + m_out = m * beta1 - (beta1 + (-1)) * gt + v_out = v * beta2 - (beta2 + (-1)) * gt * gt + + if amsgrad: + max_grad_norm_out = torch.max(max_grad_norm, v_out) + if (1 - beta2_power_out) == 0: + beta2_power_out -= eps + denom = torch.sqrt(torch.div(max_grad_norm_out, (1 - 
beta2_power_out))) + eps + else: + vraintain = torch.div(v_out, (1 - beta2_power_out)) + denom = torch.sqrt(vraintain) + eps + + if (1 - beta1_power_out) == 0: + beta1_power_out -= eps + var_out = var_t + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom) + return var_out.cpu(), m_out.cpu(), v_out.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..dd30bb18a6da92bda6466810136e24f6c45af7b7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py @@ -0,0 +1,25 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_confusion_transpose(data, perm, shape, transpose_first): + if transpose_first: + output = data.permute(*perm).contiguous().view(shape) + else: + output = data.view(shape).permute(*perm) + return output.cpu() + + +@npu_custom_grad_functions +def npu_confusion_transpose_backward(grad, perm, shape, transpose_first): + shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm] + perm_cal = [0] * len(perm) + for i, perm_dim in enumerate(perm): + perm_cal[perm_dim] = i + + if transpose_first: + result = grad.permute(*perm_cal).reshape(shape_cal) + else: + result = grad.reshape(shape_cal).permute(*perm_cal) + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..5442eff734d40d8d6c3edee27d90d8dc85d37209 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py @@ -0,0 +1,58 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def fast_gelu(input0): + attr = 1.702 + const_0 = 0 - attr + const_1 = 1 + const_2 = attr / 2 + + abs_x = torch.abs(input0) + mul_abs_x = abs_x * const_0 + exp_abs_x = torch.exp(mul_abs_x) + div_down = exp_abs_x + const_1 + + pn_x = input0 - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + div_up = input0 * exp_pn_x + div_down_rec = torch.reciprocal(div_down) + result = div_up * div_down_rec + + return result.cpu() + + +@npu_custom_grad_functions +def npu_fast_gelu_backward(grad, input_x): + const_2 = 1.702 + const_3 = 1.0 + const_1 = 0.0 - const_2 + + # e^(-1.702x) + abs_x = torch.abs(input_x) + mul_abs_x = abs_x * const_1 + exp_x = torch.exp(mul_abs_x) + + # 1.702xe^(-1.702x) + add_2 = input_x * exp_x + add_2 = add_2 * const_2 + + # e^(1.702(x-|x|)) + pn_x = input_x - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + + # e^(-1.702x) + 1.702xe^(-1.702x) + e^(1.702(x-|x|)) + div_up = exp_x + add_2 + div_up = div_up + exp_pn_x + + # (e^(-1.702x)+1)^2 + div_down_i = exp_x + const_3 + div_down = div_down_i * div_down_i + div_down_rec = torch.reciprocal(div_down) + result_temp = div_up * div_down_rec + result = grad * result_temp + + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..885b5c460edc3d734076994ee0d09dd48db27c56 --- /dev/null +++ 
b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py @@ -0,0 +1,8 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions + + +@npu_custom_functions +def npu_layer_norm_eval(data, normalized_shape): + result = torch.nn.functional.layer_norm(data, normalized_shape) + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..33b18d759d1bf1fa62525127b2998a38a0b530c4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py @@ -0,0 +1,15 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_linear(x, weight, bias): + output = torch.nn.functional.linear(x, weight, bias) + return output.cpu() + + +@npu_custom_grad_functions +def npu_linear_backward(grad, input_data, weight): + input_grad = torch.matmul(grad, weight) + weight_grad = torch.matmul(grad.t(), input_data) + return input_grad.cpu(), weight_grad.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4f7dc040f3857b417e27f4686102ac7e30a8d6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py @@ -0,0 +1,50 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_grad_functions + + +@npu_custom_grad_functions +def matmul_backward(grad, self, other, mask): + grad_self, grad_other = None, None + dim_self = self.dim() + dim_other = other.dim() + + size_grad = list(grad.size()) + size_self = list(self.size()) + size_other = list(other.size()) + if dim_self == 1 and dim_other == 1: + grad_self = other.mul(grad) if mask[0] else grad_self + grad_other = self.mul(grad) if mask[1] else grad_other + elif dim_self == 2 and dim_other == 1: + grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0)) if mask[0] else grad_self + grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1) if mask[1] else grad_other + elif dim_self == 1 and dim_other == 2: + grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self + grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other + elif dim_self >= 3 and (dim_other == 1 or dim_other == 2): + view_size = 1 if dim_other == 1 else size_grad[-1] + unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size) + if mask[0]: + grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \ + .view(size_self) + if mask[1]: + unfolded_self = self.contiguous().view([-1, size_self[-1]]) + grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other) + elif (dim_self == 1 or dim_self == 2) and dim_other >= 3: + view_size = 1 if dim_self == 1 else size_grad[-2] + unfolded_grad_T = grad.view([-1, view_size]) \ + if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size]) + if mask[0]: + # create a 2D-matrix from other + unfolded_other_T = \ + other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2) + grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self) + if mask[1]: + size_other_T = size_other[:-2] + 
size_other_T.extend(size_other[::-1][:2])
+            grad_other = \
+                unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2)
+    else:
+        grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self
+        grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other
+
+    return grad_self.cpu(), grad_other.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a49ce740abd1550ccea88d83f07f8f1e29cdbff
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -0,0 +1,424 @@
+import torch
+import numpy as np
+from einops import rearrange
+
+from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions
+from msprobe.pytorch.common.utils import logger
+
+gtype = torch.float64  # float64 is required on ARM hosts; float32 is enough on x86 (float64 also works). ARM is slow, so x86 is recommended for s=8k cases.
+softmax_build_mode = "QKV"  # "MAX_SUM"
+
+"""
+# Forward signature comparison
+Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                              atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                              next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                              gen_mask_parallel=True, sync=False
+
+# Backward signature comparison
+Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                               atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                               attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                               next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                               numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
+
+def softmax_forward(x):
+    x_max = torch.max(x, dim=-1, keepdims=True)[0]
+    x_sub = x.sub(x_max)
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=-1, keepdims=True)
+    res = y.div(x_sum)
+    return res, x_max, x_sum
+
+
+def softmax_grad(dp, softmax_res):
+    muls = dp * softmax_res
+    muls_r = muls.sum(dim=-1, keepdims=True)
+    sub_r = dp - muls_r
+    res = sub_r * softmax_res
+    return res
+
+
+def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+    if num_kv_heads == 0 or num_heads % num_kv_heads != 0:
+        raise ValueError(f"num_kv_heads must be non-zero and a divisor of num_heads.")
+
+    factor = num_heads // num_kv_heads
+    kv_shape = kv_tensor.shape
+    B = kv_shape[0]
+    S = kv_shape[2]
+    D = kv_shape[3]
+    kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+    for i in range(num_heads):
+        j = i // factor
+        kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+    return kv_res
+
+
+def calculate_qk(q, k, atten_mask, pse, scale):
+    if pse is None or len(pse.shape) == 0:
+        qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
+    else:
+        qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
+    if atten_mask is None or len(atten_mask.shape) == 0:
+        return qk
+    else:
+        qk = qk + atten_mask.bool() * (-40000.0)  # -10000
+        return qk
+
+
+def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+    qk = calculate_qk(q, k, atten_mask, pse, scale)
+    softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+    if drop_mask is None or len(drop_mask.shape) == 0:
+        drop_res = softmax_res
+    
else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob): + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale) + dq = torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + try: + if input_layout == "BSH": + B, S1, H1 = query.shape + _, S2, H2 = key.shape + D = H1 // N1 + N2 = H2 // D + elif input_layout == "SBH": + S1, B, H1 = query.shape + S2, _, H2 = key.shape + D = H1 // N1 + N2 = H2 // D + elif input_layout == "BSND": + B, S1, N1, D = query.shape + _, S2, N2, _ = key.shape + H1 = N1 * D + H2 = N2 * D + elif input_layout == "BNSD": + B, N1, S1, D = query.shape + _, N2, S2, _ = key.shape + H1 = N1 * D + H2 = N2 * D + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + if D == 0: + raise ValueError(f"Value D must be non-zero.") + DTYPE = query.dtype + return B, S1, S2, N1, N2, D, H1, H2, DTYPE + + +def convert_from_bnsd(_input, input_layout): + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + # 默认"BNSD"无需处理 + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(gtype) + + +def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype): + """ + # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现 + ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype) + """ + shape = [S1, S2] + + if atten_mask is not None: + # 当FA的输入已经包含atten_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原 + if 
sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4: + logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}") + + if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048: + if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)): + if sparse_mode == 2: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1)) + elif sparse_mode == 4: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + logger.debug(f"反向转换atten_mask {atten_mask.shape}") + return atten_mask.to(dtype) + + return atten_mask.to(dtype) + + if atten_mask is not None: + if atten_mask.dim() == 2: + if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2: + raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}") + shape = [S1, S2] + elif atten_mask.dim() == 4: + if atten_mask.shape[1] == 1: + shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2] + else: + shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2] + + if sparse_mode == 0: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + elif sparse_mode == 1: # no sparse + atten_mask = torch.from_numpy(np.zeros(shape)) + elif sparse_mode == 2: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1)) + elif sparse_mode == 4: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + # 注:不会出现sparse_mode=5的情况,该情况要求必须要传入atten_mask,且atten_mask矩阵数据格式须为BNSS或B1SS, + # 因此可以认为FA的输入已经是正确的atten_mask了 + return atten_mask.to(dtype) + + +def generate_kv(key, value, N1, N2): + # N不等长适配by cdy + if not (N1 == N2): + k_new = broadcast_kv(N1, N2, key, key.dtype) + v_new = broadcast_kv(N1, N2, value, value.dtype) + else: + k_new = key + v_new = value + return k_new, v_new + + +def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max)) + """ + logger.info("Using QKV to rebuild original softmax") + qk = calculate_qk(q, k, atten_mask, pse, scale) + softmax_res, x_max, x_sum = softmax_forward(qk) + return softmax_res + + +def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) + """ + logger.info("Using softmax_max and softmax_sum to rebuild original softmax") + qk = calculate_qk(q, k, atten_mask, pse, scale) + if softmax_max.shape[-1] == 0: + raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}") + repeat_dim = qk.shape[-1] // softmax_max.shape[-1] + softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div( + softmax_sum.repeat(1, 1, 1, repeat_dim)) + return softmax_res + + +def npu_fusion_attention_forward_patch(*args, **kwargs): + # query, key, value, head_num, input_layout + if len(args) != 5: + raise ValueError(f"Unsupported npu_fusion_attention args 
{args}.") + + B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4]) + if N1 == N2 and S1 == S2: + logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (N1 % N2 == 0 and N1 >= N2): + raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.") + + dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2, + "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE} + + new_kwargs = {"keep_prob": 1, + "scale": kwargs.get("scale", 1 / (D ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "atten_mask": kwargs.get("atten_mask")} + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention_backward_patch(*args, **kwargs): + if len(args) != 6: + raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") + + B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5]) + if N1 == N2 and S1 == S2: + logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (N1 % N2 == 0 and N1 >= N2): + raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.") + + dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2, + "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE} + + new_kwargs = {"keep_prob": 1, + "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "softmax_max": kwargs.get("softmax_max"), + "softmax_sum": kwargs.get("softmax_sum"), + "softmax_in": kwargs.get("softmax_in"), + "attention_in": kwargs.get("attention_in"), + "seed": kwargs.get("seed", 0), + "offset": kwargs.get("offset", 0), + "numels": kwargs.get("numels", 0), + "atten_mask": kwargs.get("atten_mask")} + + return args, dims_kwargs, new_kwargs + + +@npu_custom_functions +def npu_fusion_attention(*args, **kwargs): + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs) + query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4] + N1 = dims_kwargs.get("N1") + N2 = dims_kwargs.get("N2") + S1 = dims_kwargs.get("S1") + S2 = dims_kwargs.get("S2") + B = dims_kwargs.get("B") + DTYPE = dims_kwargs.get("DTYPE") + atten_mask = new_kwargs.get("atten_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + scale = new_kwargs.get("scale") + + atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE) + query = convert_to_bnsd(query, N1, input_layout) + key = convert_to_bnsd(key, N2, input_layout) + value = convert_to_bnsd(value, N2, input_layout) + k_new, v_new = generate_kv(key, value, N1, N2) + out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new, + 
drop_mask=None, atten_mask=atten_mask, + pse=pse, scale=scale, + keep_prob=keep_prob) + if out_golden.dim() == 5: + out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3), + out_golden.size(4)) + out_golden = convert_from_bnsd(out_golden, input_layout) + + return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu() + + +@npu_custom_grad_functions +def npu_fusion_attention_grad(*args, **kwargs): + # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs) + query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5] + N1 = dims_kwargs.get("N1") + N2 = dims_kwargs.get("N2") + S1 = dims_kwargs.get("S1") + S2 = dims_kwargs.get("S2") + B = dims_kwargs.get("B") + D = dims_kwargs.get("D") + DTYPE = dims_kwargs.get("DTYPE") + atten_mask = new_kwargs.get("atten_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + softmax_max = new_kwargs.get("softmax_max") + softmax_sum = new_kwargs.get("softmax_sum") + scale_value = new_kwargs.get("scale_value") + + atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE) + query = convert_to_bnsd(query, N1, input_layout) + dx = convert_to_bnsd(dx, N1, input_layout) + key = convert_to_bnsd(key, N2, input_layout) + value = convert_to_bnsd(value, N2, input_layout) + k_new, v_new = generate_kv(key, value, N1, N2) + + if softmax_build_mode == "QKV": + softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value) + else: + softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum) + + dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob) + + # N不等长适配by cdy + if not (N1 == N2): + if N2 == 0: + raise ValueError("dims_kwargs.N2 must be non-zero.") + G = int(N1 / N2) + dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D) + dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D) + + if dq.dim() == 5: + dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4)) + if dk.dim() == 5: + dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4)) + if dv.dim() == 5: + dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4)) + + dq = convert_from_bnsd(dq, input_layout) + dk = convert_from_bnsd(dk, input_layout) + dv = convert_from_bnsd(dv, input_layout) + + return dq.cpu(), dk.cpu(), dv.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe6c834a4c380d50cca37802bc62552d7c95eaa --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py @@ -0,0 +1,18 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_rms_norm(x, gamma, epsilon=1e-5): + rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon) + res = x * rstd * gamma + return res.cpu(), rstd.float().cpu() + + +@npu_custom_grad_functions 
+def npu_rms_norm_backward(grad, x, gamma, rstd): + mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True) + grad_x = (grad * gamma - x * rstd * mean_gy) * rstd + grad_gamma = x * grad * rstd + return grad_x.cpu(), grad_gamma.cpu() + diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..76b3828da3d988476ef71760006e1d32ffcece94 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py @@ -0,0 +1,55 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_rotary_mul(x, r1, r2): + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + output = r1 * x + r2 * x_new + return output.cpu() + + +@npu_custom_grad_functions +def npu_rotary_mul_backward(dy_tensor, x, r1, r2): + x.requires_grad = True + r1.requires_grad = True + r2.requires_grad = True + # golden + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + golden_tensor = r1 * x + r2 * x_new + golden_tensor.backward(dy_tensor) + r1_shape = r1.shape + r1_grad = torch.zeros(r1_shape).type(torch.float32) + r2_grad = torch.zeros(r1_shape).type(torch.float32) + x1, x2 = torch.chunk(x.float(), 2, -1) + x_new2 = torch.cat((-x2, x1), dim=-1) + x_shape = x.shape + h = x.float() + grad = dy_tensor.float() + condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3])) + condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3])) + condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3])) + if condition_1: + for i in range(x_shape[0]): + for j in range(x_shape[2]): + r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :]) + r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :]) + elif condition_2: + for i in range(x_shape[0]): + for j in range(x_shape[1]): + r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :]) + r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :]) + elif condition_3: + for i in range(x_shape[1]): + for j in range(x_shape[2]): + r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :]) + r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :]) + return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc523ee40fd4910d4a43f79cf8f7b2b9cbe2f8d --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py @@ -0,0 +1,29 @@ +import torch +from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask): + if fixed_triu_mask: + mask = 
(torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
+    dtype = x.dtype
+    x = (x * scale).masked_fill(mask, value=-10000)
+    x = x - torch.max(x, dim=-1, keepdims=True)[0]
+    x = torch.exp(x.float())
+    y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
+    return y.to(dtype).cpu()
+
+
+@npu_custom_grad_functions
+def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
+    if fixed_triu_mask:
+        mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
+    dtype = y_grad.dtype
+    y_grad = y_grad.float()
+    y = y.float()
+    x_grad = y_grad * y
+    x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
+    x_grad = x_grad * y
+    x_grad = x_grad * scale
+    x_grad = x_grad.masked_fill(mask, value=0)
+    return x_grad.to(dtype).cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..973be454d3e606e0c70e1c99e5ed901f3164d09d
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
@@ -0,0 +1,58 @@
+import torch
+from msprobe.pytorch.function_factory import npu_custom_functions, npu_custom_grad_functions
+
+
+@npu_custom_functions
+def npu_swiglu(x, dim=-1):
+    tensor_dtype = x.dtype
+
+    inTensors = torch.chunk(x, 2, dim=dim)
+    if tensor_dtype == torch.float32:
+        tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
+        output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
+    else:
+        tensor_self_float = inTensors[0].type(torch.float)
+        tensor_other_float = inTensors[1].type(torch.float)
+        tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
+            torch.float32) * tensor_other_float
+        output_data = tensor_out_float.type(tensor_dtype)
+    return output_data.cpu()
+
+
+@npu_custom_grad_functions
+def npu_swiglu_backward(grad, x, dim=-1):
+    tensor_dtype = grad.dtype
+    in_tensors = torch.chunk(x, 2, dim=dim)
+    tensor_grad_out = grad
+
+    if tensor_dtype == torch.float16:
+        tensor_out1 = torch.mul(
+            torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
+            tensor_grad_out.type(torch.float32)).type(torch.float16)
+        tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
+                                swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)
+        output = torch.cat((tensor_out1, tensor_out2), dim)
+    elif tensor_dtype == torch.bfloat16:
+        tensor_self_float = in_tensors[0].type(torch.float)
+        tensor_other_float = in_tensors[1].type(torch.float)
+        tensor_gradout_float = tensor_grad_out.type(torch.float)
+
+        tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
+            torch.float32) * tensor_other_float
+        tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float
+        tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
+        output = tensor_out_float.type(torch.bfloat16)
+    else:
+        tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)
+        tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))
+        output = torch.cat((tensor_out1, tensor_out2), dim)
+    return output.cpu()
+
+
+def swish_grad(beta, x):
+    return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
+
+
+def swish(beta, x):
+    return x * torch.sigmoid(beta * x)
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py 
b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
index acc1de105148899a46fe381b94cb7772f26d3ac8..181491488f91049148c51827387edc53d21d41cf 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
@@ -14,10 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+import logging
 import os
 import random
 import stat
 import torch
+import torch.distributed as dist
 import numpy as np
 from functools import wraps
 from msprobe.core.common.exceptions import DistributedNotInitializedError
@@ -221,3 +223,36 @@ class Const:
     CONVERT_API = {
         "int32_to_int64": ["cross_entropy"]
     }
+
+
+def get_tensor_rank(in_feat, out_feat):
+    if dist.is_initialized():
+        return dist.get_rank()
+
+    def get_tensor_rank_single(x):
+        if isinstance(x, (list, tuple)):
+            if len(x) > 0:
+                return get_tensor_rank_single(x[0])
+        elif isinstance(x, torch.Tensor):
+            device = x.device
+            if device.type != 'cpu':
+                return device.index
+        return None
+
+    in_rank = get_tensor_rank_single(in_feat)
+    out_rank = get_tensor_rank_single(out_feat)
+    tensor_rank = in_rank if in_rank is not None else out_rank
+    return tensor_rank
+
+
+def _create_logger(level=logging.INFO):
+    logger_ = logging.getLogger()
+    logger_.setLevel(level)
+    ch = logging.StreamHandler()
+    ch.setLevel(level)
+    logger_.addHandler(ch)
+    return logger_
+
+
+log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
+logger = _create_logger(log_level)
diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d840e561ea083c672fcd527432aa0ebfe1dee6b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
@@ -0,0 +1,66 @@
+from msprobe.pytorch.common.utils import logger
+
+
+class Register(dict):
+    def __init__(self, *args, **kwargs):
+        super(Register, self).__init__(*args, **kwargs)
+        self._dict = {}
+
+    def __call__(self, target):
+        return self.register(target)
+
+    def __setitem__(self, key, value):
+        self._dict[key] = value
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __contains__(self, key):
+        return key in self._dict
+
+    def __str__(self):
+        return str(self._dict)
+
+    def keys(self):
+        return self._dict.keys()
+
+    def values(self):
+        return self._dict.values()
+
+    def items(self):
+        return self._dict.items()
+
+    def register(self, target):
+
+        def add_register_item(key, value):
+            if key in self._dict:
+                logger.warning(f"{value.__name__} has been registered before, so it will be overridden.")
+            self[key] = value
+            return value
+
+        if callable(target):
+            return add_register_item(target.__name__, target)
+        else:
+            raise Exception(f"The func {target} is not callable.")
+
+
+npu_custom_functions = Register()
+npu_custom_grad_functions = Register()
+
+from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
+from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
+    npu_confusion_transpose_backward
+from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward
+from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
+from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
+from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
+from msprobe.pytorch.bench_functions.npu_fusion_attention import 
softmax_forward, softmax_grad, broadcast_kv, \ + calculate_qk, fusion_attention_forward, fusion_attention_backward, parse_bsnd_args, convert_from_bnsd, \ + convert_to_bnsd, generate_atten_mask, generate_kv, rebuid_softmax_by_qkv, rebuild_softmax_by_max_sum, \ + npu_fusion_attention_forward_patch, npu_fusion_attention_backward_patch, npu_fusion_attention, \ + npu_fusion_attention_grad +from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward +from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward +from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \ + npu_scaled_masked_softmax_backward +from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py index 4617e4854fcbb3b7ac60536886b74387cb01d99b..a02abbe5f4b7e551faf2c4ff465271ae9bebffde 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py @@ -24,12 +24,14 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const from msprobe.core.common.file_check import FileOpen - +from msprobe.pytorch.function_factory import npu_custom_grad_functions cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") with FileOpen(yaml_path, 'r') as f: - WrapAtenOps = yaml.safe_load(f).get('aten') + Ops = yaml.safe_load(f) + WrapAtenOps = Ops.get('aten') + WhiteAtenOps = Ops.get('white_aten_ops', []) aten_func = {} @@ -48,7 +50,7 @@ class HOOKAtenOP(object): class AtenOPTemplate(HOOKModule): - def __init__(self, op, hook): + def __init__(self, op, hook, need_hook=True): if isinstance(op, torch._ops.OpOverloadPacket): op_name_ = op._qualified_op_name.split("::")[-1] else: @@ -58,10 +60,21 @@ class AtenOPTemplate(HOOKModule): op_name_ = op_name_ + '.' 
+ overload_name
         self.op = op
         self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
-        super().__init__(hook)
+        self.need_hook = need_hook
+        if self.need_hook:
+            super().__init__(hook)
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
+        if isinstance(self.op, str):
+            if self.op in npu_custom_grad_functions:
+                return npu_custom_grad_functions[self.op](*args, **kwargs)
+            if self.op in WhiteAtenOps:
+                return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
+            if self.op not in aten_func:
+                raise Exception(f"Skip op[{self.op}] accuracy check, because the op is neither "
+                                f"in dir(torch.ops.aten) nor in the support yaml.")
+            return aten_func[self.op](*args, **kwargs)
         return self.op(*args, **kwargs)
 
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
index 992713bce57b0f73b5d7144f4fb3a04726f70468..8a67ed94290d9ba02e947f5806daece13f041e9e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
@@ -17,19 +17,26 @@
 import os
 import torch
-import torch_npu
 import yaml
 
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_check import FileOpen
+from msprobe.pytorch.function_factory import npu_custom_functions
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
 with FileOpen(yaml_path, 'r') as f:
     WrapNpuOps = yaml.safe_load(f).get('torch_npu')
 
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+else:
+    is_gpu = False
+
 
 def get_npu_ops():
     global WrapNpuOps
@@ -46,13 +53,19 @@ class HOOKNpuOP(object):
 
 
 class NpuOPTemplate(HOOKModule):
-    def __init__(self, op_name, hook):
+    def __init__(self, op_name, hook, need_hook=True):
         self.op_name_ = op_name
         self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
-        super().__init__(hook)
+        self.need_hook = need_hook
+        if need_hook:
+            super().__init__(hook)
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
+        if not self.need_hook:
+            if self.op_name_ not in npu_custom_functions:
+                raise Exception(f'There is no benchmark function for {self.op_name_}')
+            return npu_custom_functions[self.op_name_](*args, **kwargs)
         if torch_without_guard_version:
            return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
         else:
@@ -60,7 +73,6 @@ def wrap_npu_op(op_name, hook):
-
     def npu_op_template(*args, **kwargs):
         return NpuOPTemplate(op_name, hook)(*args, **kwargs)
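
Reviewer note: the snippet below is not part of the change set above. It is a minimal, self-contained sketch (hypothetical names only: DemoRegister, demo_custom_functions, run_bench) of the dispatch pattern this diff introduces in function_factory.py and wrap_npu_custom.py — a registry maps an op name to a CPU benchmark implementation, and a template constructed with need_hook=False resolves the op through that registry instead of calling torch_npu. Only plain torch is required, so it runs on CPU/GPU hosts without torch_npu installed.

# Standalone illustration of the registry + need_hook=False dispatch (assumed simplification).
import torch


class DemoRegister(dict):
    """Minimal stand-in for function_factory.Register: maps a function name to the function."""

    def __call__(self, target):
        if not callable(target):
            raise TypeError(f"{target} is not callable")
        self[target.__name__] = target
        return target


demo_custom_functions = DemoRegister()


@demo_custom_functions
def npu_rms_norm(x, gamma, epsilon=1e-5):
    # CPU benchmark implementation standing in for the fused torch_npu operator.
    rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), dim=-1, keepdim=True) + epsilon)
    return (x * rstd * gamma).cpu(), rstd.float().cpu()


def run_bench(op_name, *args, **kwargs):
    # Mirrors NpuOPTemplate.forward with need_hook=False: look the op up in the
    # registry and run the benchmark implementation instead of torch.ops.npu.
    if op_name not in demo_custom_functions:
        raise KeyError(f"no benchmark function registered for {op_name}")
    return demo_custom_functions[op_name](*args, **kwargs)


if __name__ == "__main__":
    x = torch.randn(2, 4)
    gamma = torch.ones(4)
    out, rstd = run_bench("npu_rms_norm", x, gamma)
    print(out.shape, rstd.shape)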