From 748bdca629821f31637aa0d711c020e6250868ef Mon Sep 17 00:00:00 2001 From: yang-minghai22 Date: Fri, 26 Apr 2024 10:23:01 +0800 Subject: [PATCH] support tcp/ip online api check --- .../bench_functions/__init__.py | 11 + .../bench_functions/apply_adam_w.py | 26 + .../bench_functions/confusion_transpose.py | 25 + .../bench_functions/fast_gelu.py | 58 ++ .../bench_functions/layer_norm_eval.py | 8 + .../bench_functions/linear.py | 15 + .../bench_functions/matmul_backward.py | 51 ++ .../bench_functions/rms_norm.py | 18 + .../bench_functions/rotary_mul.py | 52 ++ .../bench_functions/scaled_mask_softmax.py | 29 + .../bench_functions/swiglu.py | 58 ++ .../api_accuracy_checker/common/config.py | 26 +- .../common/function_factory.py | 47 ++ .../api_accuracy_checker/common/utils.py | 41 + .../api_accuracy_checker/compare/compare.py | 84 +- .../api_accuracy_checker/config.yaml | 5 + .../api_accuracy_checker/dump/dispatch.py | 86 ++ .../api_accuracy_checker/dump/dump.py | 49 +- .../dump/torch_ops_config.yaml | 60 ++ .../hook_module/register_hook.py | 14 + .../hook_module/support_wrap_ops.yaml | 743 ++++++++++++++++++ .../api_accuracy_checker/hook_module/utils.py | 5 +- .../hook_module/wrap_aten.py | 88 +++ .../hook_module/wrap_npu_custom.py | 83 ++ .../api_accuracy_checker/run_ut/run_ut.py | 120 ++- .../tensor_transport_layer/attl.py | 150 ++++ .../tensor_transport_layer/client.py | 309 ++++++++ .../tensor_transport_layer/device_dispatch.py | 104 +++ .../tensor_transport_layer/server.py | 216 +++++ .../test/ut/compare/test_compare.py | 4 +- .../test/ut/run_ut/test_run_ut.py | 8 +- 31 files changed, 2523 insertions(+), 70 deletions(-) create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/__init__.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/apply_adam_w.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/confusion_transpose.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/fast_gelu.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/layer_norm_eval.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/linear.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/matmul_backward.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/rms_norm.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/rotary_mul.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/scaled_mask_softmax.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/bench_functions/swiglu.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/common/function_factory.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/dump/dispatch.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/dump/torch_ops_config.yaml create mode 100644 debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/client.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/device_dispatch.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/server.py diff --git 
a/debug/accuracy_tools/api_accuracy_checker/bench_functions/__init__.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/__init__.py new file mode 100644 index 000000000..69b0de0d7 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/__init__.py @@ -0,0 +1,11 @@ +import os +from pkgutil import iter_modules +from importlib import import_module + +package_path = os.path.dirname(os.path.realpath(__file__)) +for _, module_name, _ in iter_modules([package_path]): + module = import_module(f"{__name__}.{module_name}") + for attr_name in dir(module): + attr = getattr(module, attr_name) + if callable(attr) and "npu_custom" not in attr_name: + globals()[attr_name] = attr diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/apply_adam_w.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/apply_adam_w.py new file mode 100644 index 000000000..4772fe21b --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/apply_adam_w.py @@ -0,0 +1,26 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions + + +@npu_custom_functions +def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay, + beta1, beta2, eps, grad, max_grad_norm, amsgrad, maximize, out): + var, m, v = out + if amsgrad: + max_grad_norm = (torch.rand(var.shape) * 10.0 - 5.0).to(var.dtype) + gt = -grad if maximize else grad + m_out = m * beta1 - (beta1 + (-1)) * gt + v_out = v * beta2 - (beta2 + (-1)) * gt * gt + var_t = var * (1 + (-lr * weight_decay)) + beta1_power_out = beta1_power * beta1 + beta2_power_out = beta2_power * beta2 + if amsgrad: + max_grad_norm_out = torch.max(max_grad_norm, v_out) + denom = torch.sqrt(torch.div(max_grad_norm_out, (1 - beta2_power_out))) + eps + else: + vraintain = torch.div(v_out, (1 - beta2_power_out)) + denom = torch.sqrt(vraintain) + eps + + var_out = var_t + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom) + return var_out, m_out, v_out + diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/confusion_transpose.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/confusion_transpose.py new file mode 100644 index 000000000..d7323f272 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/confusion_transpose.py @@ -0,0 +1,25 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_confusion_transpose(data, perm, shape, transpose_first): + if transpose_first: + output = data.permute(*perm).contiguous().view(shape) + else: + output = data.view(shape).permute(*perm) + return output + + +@npu_custom_grad_functions +def npu_confusion_transpose_backward(grad, perm, shape, transpose_first): + shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm] + perm_cal = [0] * len(perm) + for i, perm_dim in enumerate(perm): + perm_cal[perm_dim] = i + + if transpose_first: + result = grad.permute(*perm_cal).reshape(shape_cal) + else: + result = grad.reshape(shape_cal).permute(*perm_cal) + return result diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/fast_gelu.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/fast_gelu.py new file mode 100644 index 000000000..3a04a24b6 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/fast_gelu.py @@ -0,0 +1,58 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, 
npu_custom_grad_functions + + +@npu_custom_functions +def fast_gelu(input0): + attr = 1.702 + const_0 = 0 - attr + const_1 = 1 + const_2 = attr / 2 + + abs_x = torch.abs(input0) + mul_abs_x = abs_x * const_0 + exp_abs_x = torch.exp(mul_abs_x) + div_down = exp_abs_x + const_1 + + pn_x = input0 - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + div_up = input0 * exp_pn_x + div_down_rec = torch.reciprocal(div_down) + result = div_up * div_down_rec + + return result + + +@npu_custom_grad_functions +def npu_fast_gelu_backward(grad, input_x): + const_2 = 1.702 + const_3 = 1.0 + const_1 = 0.0 - const_2 + + # e^(-1.702x) + abs_x = torch.abs(input_x) + mul_abs_x = abs_x * const_1 + exp_x = torch.exp(mul_abs_x) + + # 1.702xe^(-1.702x) + add_2 = input_x * exp_x + add_2 = add_2 * const_2 + + # e^(1.702(x-|x|)) + pn_x = input_x - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + + # e^(-1.702x) + 1.702xe^(-1.702x) + e^(1.702(x-|x|)) + div_up = exp_x + add_2 + div_up = div_up + exp_pn_x + + # (e^(-1.702x)+1)^2 + div_down_i = exp_x + const_3 + div_down = div_down_i * div_down_i + div_down_rec = torch.reciprocal(div_down) + result_temp = div_up * div_down_rec + result = grad * result_temp + + return result diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/layer_norm_eval.py new file mode 100644 index 000000000..f257c3566 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/layer_norm_eval.py @@ -0,0 +1,8 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions + + +@npu_custom_functions +def npu_layer_norm_eval(data, normalized_shape): + result = torch.nn.functional.layer_norm(data, normalized_shape) + return result diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/linear.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/linear.py new file mode 100644 index 000000000..48bb08c4f --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/linear.py @@ -0,0 +1,15 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_linear(x, weight, bias): + output = torch.nn.functional.linear(x, weight, bias) + return output + + +@npu_custom_grad_functions +def npu_linear_backward(grad, input_data, weight): + input_grad = torch.matmul(grad, weight) + weight_grad = torch.matmul(grad.t(), input_data) + return input_grad, weight_grad diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/matmul_backward.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/matmul_backward.py new file mode 100644 index 000000000..c76850905 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/matmul_backward.py @@ -0,0 +1,51 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_grad_functions + + +@npu_custom_grad_functions +def matmul_backward(grad, self, other, mask): + grad_self, grad_other = None, None + dim_self = self.dim() + dim_other = other.dim() + + size_grad = list(grad.size()) + size_self = list(self.size()) + size_other = list(other.size()) + if dim_self == 1 and dim_other == 1: + grad_self = other.mul(grad) if mask[0] else grad_self + grad_other = self.mul(grad) if mask[1] else grad_other + elif dim_self == 2 and dim_other == 1: + grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0)) if mask[0] 
else grad_self + grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1) if mask[1] else grad_other + elif dim_self == 1 and dim_other == 2: + grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self + grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other + elif dim_self >= 3 and (dim_other == 1 or dim_other == 2): + view_size = 1 if dim_other == 1 else size_grad[-1] + unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size) + if mask[0]: + grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \ + .view(size_self) + print(f'size_self: {size_self}') + if mask[1]: + unfolded_self = self.contiguous().view([-1, size_self[-1]]) + grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other) + elif (dim_self == 1 or dim_self == 2) and dim_other >= 3: + view_size = 1 if dim_self == 1 else size_grad[-2] + unfolded_grad_T = grad.view([-1, view_size]) \ + if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size]) + if mask[0]: + # create a 2D-matrix from other + unfolded_other_T = \ + other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2) + grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self) + if mask[1]: + size_other_T = size_other[:-2] + size_other_T.extend(size_other[::-1][:2]) + grad_other = \ + unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2) + else: + grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self + grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other + + return grad_self, grad_other diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/rms_norm.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/rms_norm.py new file mode 100644 index 000000000..bdf7ea616 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/rms_norm.py @@ -0,0 +1,18 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_rms_norm(x, gamma, eps=1e-5): + rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + eps) + res = x * rstd * gamma + return res, rstd + + +@npu_custom_grad_functions +def npu_rms_norm_backward(grad, x, gamma, rstd): + mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True) + grad_x = (grad * gamma - x * rstd * mean_gy) * rstd + grad_gamma = x * grad * rstd + return grad_x, grad_gamma + diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/rotary_mul.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/rotary_mul.py new file mode 100644 index 000000000..cad5459c7 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/rotary_mul.py @@ -0,0 +1,52 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_rotary_mul(x, r1, r2): + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + output = r1 * x + r2 * x_new + return output + + +@npu_custom_grad_functions +def npu_rotary_mul_backward(dy_tensor, x, r1, r2): + # golden + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + golden_tensor = r1 * x + r2 * x_new + golden_tensor.backward(dy_tensor) + r1_shape = r1.shape + r1_grad = 
torch.zeros(r1_shape).type(torch.float32) + r2_grad = torch.zeros(r1_shape).type(torch.float32) + x1, x2 = torch.chunk(x.float(), 2, -1) + x_new2 = torch.cat((-x2, x1), dim=-1) + x_shape = x.shape + h = x.float() + grad = dy_tensor.float() + condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3])) + condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3])) + condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3])) + if condition_1: + for i in range(x_shape[0]): + for j in range(x_shape[2]): + r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :]) + r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :]) + elif condition_2: + for i in range(x_shape[0]): + for j in range(x_shape[1]): + r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :]) + r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :]) + elif condition_3: + for i in range(x_shape[1]): + for j in range(x_shape[2]): + r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :]) + r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :]) + return x.grad, r1_grad, r2_grad diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/scaled_mask_softmax.py new file mode 100644 index 000000000..8e99ab200 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/scaled_mask_softmax.py @@ -0,0 +1,29 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask): + if fixed_triu_mask: + mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device) + dtype = x.dtype + x = (x * scale).masked_fill(mask, value=-10000) + x = x - torch.max(x, dim=-1, keepdims=True)[0] + x = torch.exp(x.float()) + y = torch.div(x, torch.sum(x, dim=-1, keepdims=True)) + return y.to(dtype) + + +@npu_custom_grad_functions +def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask): + if fixed_triu_mask: + mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device) + dtype = y_grad.dtype + y_grad = y_grad.float() + y = y.float() + x_grad = y_grad * y + x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True) + x_grad = x_grad * y + x_grad = x_grad * scale + x_grad = x_grad.masked_fill(mask, value=0) + return x_grad.to(dtype) diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/swiglu.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/swiglu.py new file mode 100644 index 000000000..6685b5f47 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/swiglu.py @@ -0,0 +1,58 @@ +import torch +from api_accuracy_checker.common.function_factory import npu_custom_functions, npu_custom_grad_functions + + +@npu_custom_functions +def npu_swiglu(x, dim=-1): + tensor_dtype = x.dtype + + inTensors = torch.chunk(x, 2, dim=dim) + 
if tensor_dtype == torch.float32: + tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0)) + output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1]) + else: + tensor_self_float = inTensors[0].type(torch.float) + tensor_other_float = inTensors[1].type(torch.float) + tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type( + torch.float32) * tensor_other_float + output_data = tensor_out_float.type(tensor_dtype) + return output_data + + +@npu_custom_grad_functions +def npu_swiglu_backward(grad, x, dim=-1): + tensor_dtype = grad.dtype + in_tensors = torch.chunk(x, 2, dim=dim) + tensor_grad_out = grad + + if tensor_dtype == torch.float16: + tensor_out1 = torch.mul( + torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))), + tensor_grad_out.type(torch.float32)).type(torch.float16) + tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32), + swish(1, in_tensors[0].type(torch.float32))).type(torch.float16) + output = torch.cat((tensor_out1, tensor_out2), dim) + elif tensor_dtype == torch.bfloat16: + tensor_self_float = in_tensors[0].type(torch.float) + tensor_other_float = in_tensors[1].type(torch.float) + tensor_gradout_float = tensor_grad_out.type(torch.float) + + tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type( + torch.float32) * tensor_other_float + tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float + tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim) + output = tensor_out_float.type(torch.bfloat16) + else: + tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out) + tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0])) + output = torch.cat((tensor_out1, tensor_out2), dim) + return output + + +def swish_grad(beta, x): + return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta + + +def swish(beta, x): + return x * torch.sigmoid(beta * x) + diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 57f59b078..8001257e8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -23,7 +23,12 @@ class Config: 'white_list': list, 'error_data_path': str, 'jit_compile': bool, - 'precision': int + 'precision': int, + 'is_online': bool, + 'is_golden': bool, + 'host': str, + 'port': int, + 'rank_list': list } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -56,13 +61,20 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, enable_dataloader=None): + def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, enable_dataloader=None, + is_online=None, port=None, host=None, rank_list=None): args = { - "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), - "real_data": real_data if real_data else self.config.get("real_data", False), - "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), - "white_list": white_list if white_list else self.config.get("white_list", []), - "enable_dataloader": enable_dataloader if enable_dataloader else 
self.config.get("enable_dataloader", False) + "dump_path": dump_path if dump_path is not None else self.config.get("dump_path", './'), + "real_data": real_data if real_data is not None else self.config.get("real_data", False), + "target_iter": target_iter if target_iter is not None else self.config.get("target_iter", [1]), + "white_list": white_list if white_list is not None else self.config.get("white_list", []), + "enable_dataloader": enable_dataloader + if enable_dataloader is not None else self.config.get("enable_dataloader", False), + "is_online": is_online if is_online is not None else self.config.get("is_online", False), + "is_golden": False if is_online and host is not None else True, + "host": host if host is not None else self.config.get("host", "127.0.0.1"), + "port": port if port is not None else self.config.get("port", 30001), + "rank_list": rank_list if rank_list is not None else self.config.get("rank_list", [0]) } for key, value in args.items(): if key in self.config: diff --git a/debug/accuracy_tools/api_accuracy_checker/common/function_factory.py b/debug/accuracy_tools/api_accuracy_checker/common/function_factory.py new file mode 100644 index 000000000..fb646de06 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/common/function_factory.py @@ -0,0 +1,47 @@ +class Register(dict): + def __init__(self, *args, **kwargs): + super(Register, self).__init__(*args, **kwargs) + self._dict = {} + + def register(self, target): + + def add_register_item(key, value): + if key in self._dict: + print(f"warning: {value.__name__} has been registered before, so we will overriden it.") + self[key] = value + return value + + if callable(target): + return add_register_item(target.__name__, target) + else: + raise Exception(f"The func {target} is not callable.") + + def __call__(self, target): + return self.register(target) + + def __setitem__(self, key, value): + self._dict[key] = value + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def __str__(self): + return str(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + +npu_custom_functions = Register() +npu_custom_grad_functions = Register() + +from api_accuracy_checker.bench_functions import * diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f6f4d26f5..1113ccd76 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -24,10 +24,12 @@ import subprocess import sys import time import csv +import logging from datetime import datetime, timezone import numpy as np import torch +import torch.distributed as dist try: import torch_npu @@ -649,3 +651,42 @@ def get_full_data_path(data_path, real_data_path): return data_path full_data_path = os.path.join(real_data_path, data_path) return os.path.realpath(full_data_path) + + +def get_tensor_rank(in_feat, out_feat): + if dist.is_initialized(): + return dist.get_rank() + + def get_tensor_rank_single(x): + if isinstance(x, (list, tuple)): + if len(x) > 0: + return get_tensor_rank_single(x[0]) + return None + elif isinstance(x, torch.Tensor): + device = x.device + if device.type == 'cpu': + return None + else: + return device.index + return None + + in_rank = get_tensor_rank_single(in_feat) + if in_rank is not None: + return in_rank + out_rank = 
diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py
index f6f4d26f5..1113ccd76 100644
--- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py
+++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py
@@ -24,10 +24,12 @@ import subprocess
 import sys
 import time
 import csv
+import logging
 from datetime import datetime, timezone
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 try:
     import torch_npu
@@ -649,3 +651,42 @@ def get_full_data_path(data_path, real_data_path):
         return data_path
     full_data_path = os.path.join(real_data_path, data_path)
     return os.path.realpath(full_data_path)
+
+
+def get_tensor_rank(in_feat, out_feat):
+    if dist.is_initialized():
+        return dist.get_rank()
+
+    def get_tensor_rank_single(x):
+        if isinstance(x, (list, tuple)):
+            if len(x) > 0:
+                return get_tensor_rank_single(x[0])
+            return None
+        elif isinstance(x, torch.Tensor):
+            device = x.device
+            if device.type == 'cpu':
+                return None
+            else:
+                return device.index
+        return None
+
+    in_rank = get_tensor_rank_single(in_feat)
+    if in_rank is not None:
+        return in_rank
+    out_rank = get_tensor_rank_single(out_feat)
+    if out_rank is not None:
+        return out_rank
+    return None
+
+
+def _create_logger(level=logging.INFO):
+    logger_ = logging.getLogger()
+    logger_.setLevel(level)
+    ch = logging.StreamHandler()
+    ch.setLevel(level)
+    logger_.addHandler(ch)
+    return logger_
+
+
+log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
+logger = _create_logger(log_level)
diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
index db549ec1d..a4ded976a 100644
--- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
@@ -1,6 +1,8 @@
 # Perform the comparison and present the results
 import os
 import csv
+from collections import namedtuple
+
 import torch
 import numpy as np
 from rich.table import Table
@@ -18,6 +20,10 @@ from api_accuracy_checker.common.config import msCheckerConfig
 from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
 
+ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
+                                       'fwd_compare_alg_results', 'bwd_compare_alg_results', 'rank'])
+
+
 class Comparator:
     # consts for result csv
     COLUMN_API_NAME = "API name"
@@ -26,9 +32,17 @@ class Comparator:
     COLUMN_STACK_INFO = "Traceback callstack info"
 
     def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
-        self.save_path = result_csv_path
-        self.detail_save_path = details_csv_path
-        if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path):
+        self.save_path_str = result_csv_path
+        self.detail_save_path_str = details_csv_path
+        self.save_path_list = [result_csv_path]
+        self.detail_save_path_list = [details_csv_path]
+        if msCheckerConfig.is_online:
+            self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
+            self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
+            self.save_path_list = [self.save_path_str.format(rank) for rank in msCheckerConfig.rank_list]
+            self.detail_save_path_list = [self.detail_save_path_str.format(rank) for rank in msCheckerConfig.rank_list]
+
+        if not is_continue_run_ut:
             self.write_csv_title()
         if stack_info_json_path:
             self.stack_info = get_json_contents(stack_info_json_path)
@@ -36,12 +50,18 @@ class Comparator:
             self.stack_info = None
 
         self.test_result_cnt = {
-            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
-            "total_num": 0, "forward_or_backward_fail_num": 0
+            "success_num": 0, "warning_num": 0, "error_num": 0,
+            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0,
+            "total_num": 0, "total_skip_num": 0
         }
 
+    @staticmethod
+    def get_path_from_rank(rank, path_list, path_pattern):
+        return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)
+
     def print_pretest_result(self):
-        self.get_statistics_from_result_csv()
+        for save_path in self.save_path_list:
+            self.get_statistics_from_result_csv(save_path)
         total_tests = self.test_result_cnt.get("total_num", 0)
         if total_tests != 0:
             passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests)
@@ -74,17 +94,12 @@ class Comparator:
         console.print(table_total)
         console.print(table_detail)
 
-    def get_statistics_from_result_csv(self):
+    def get_statistics_from_result_csv(self, save_path):
         checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, 
CompareConst.SKIP, "skip"] - self.test_result_cnt = { - "success_num": 0, "warning_num": 0, "error_num": 0, - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "total_num": 0, "total_skip_num": 0 - } - with FileOpen(self.save_path, 'r') as file: + with FileOpen(save_path, 'r') as file: reader = csv.reader(file) result_csv_rows = [row for row in reader] - result_csv_name = os.path.basename(self.save_path) + result_csv_name = os.path.basename(save_path) for item in result_csv_rows[1:]: if not isinstance(item, list) or len(item) < 3: raise ValueError("The number of columns in %s is incorrect" % result_csv_name) @@ -115,9 +130,11 @@ class Comparator: def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]] - write_csv(summary_test_rows, self.save_path) - - write_csv(DETAIL_TEST_ROWS, self.detail_save_path) + for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list): + if not os.path.exists(save_path): + write_csv(summary_test_rows, save_path) + if not os.path.exists(detail_save_path): + write_csv(DETAIL_TEST_ROWS, detail_save_path) def write_summary_csv(self, test_result): test_rows = [] @@ -132,7 +149,8 @@ class Comparator: stack_info = "\n".join(self.stack_info[name]) df_row.append(stack_info) test_rows.append(df_row) - write_csv(test_rows, self.save_path) + save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str) + write_csv(test_rows, save_path) def write_detail_csv(self, test_result): test_rows = [] @@ -153,22 +171,25 @@ class Comparator: if isinstance(item, float) else item for item in test_subject] test_rows.append([subject] + list(test_subject)) - write_csv(test_rows, self.detail_save_path) + detail_save_path = self.get_path_from_rank(test_result[-1], + self.detail_save_path_list, + self.detail_save_path_str) + write_csv(test_rows, detail_save_path) - def record_results(self, *args): + def record_results(self, args): self.write_summary_csv(args) self.write_detail_csv(args) def compare_output(self, full_api_name, data_info): _, api_name, _ = full_api_name.split("*") - bench_output = data_info.bench_output - device_output = data_info.device_output - bench_grad = data_info.bench_grad - device_grad = data_info.device_grad + bench_output, device_output = data_info.bench_out, data_info.device_out + bench_grad, device_grad = data_info.bench_grad_out, data_info.device_grad_out backward_message = data_info.backward_message compare_func = self._compare_dropout if "dropout" in full_api_name else self._compare_core_wrapper + # forward result compare fwd_success_status, fwd_compare_alg_results = compare_func(api_name, bench_output, device_output) - if not (bench_grad and device_grad): + # backward result compare + if bench_grad is None or device_grad is None: bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, []) else: if "dropout" in full_api_name: @@ -177,10 +198,17 @@ class Comparator: bwd_success_status, bwd_compare_alg_results = compare_func(api_name, bench_grad, device_grad) if backward_message: backward_column = CompareColumn() - bwd_compare_alg_results = backward_column.to_column_value(CompareConst.SKIP, backward_message) - self.record_results(full_api_name, fwd_success_status, CompareConst.SKIP, fwd_compare_alg_results, [bwd_compare_alg_results]) + bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)] else: - self.record_results(full_api_name, 
fwd_success_status, bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE, fwd_compare_alg_results, bwd_compare_alg_results)
+            bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
+
+        result_info = ResultInfo(full_api_name,
+                                 fwd_success_status,
+                                 bwd_success_status,
+                                 fwd_compare_alg_results,
+                                 bwd_compare_alg_results,
+                                 data_info.rank)
+        self.record_results(result_info)
         return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
             or bwd_success_status == CompareConst.SPACE
diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml
index a6e70c57e..24cd96997 100644
--- a/debug/accuracy_tools/api_accuracy_checker/config.yaml
+++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml
@@ -6,4 +6,9 @@ white_list: []
 error_data_path: './'
 jit_compile: True
 precision: 14
+is_online: False
+is_golden: True
+host: "127.0.0.1"
+port: 30001
+rank_list: [0]
\ No newline at end of file
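The five new keys split a run into the two halves of the online check: the benchmark side keeps is_golden: True and listens on host:port, while the NPU side connects to it and only ships tensors for the ranks in rank_list. A hedged sketch of turning the client side on through the update_config extension from earlier in this patch (the address and ranks are illustrative):

from api_accuracy_checker.common.config import msCheckerConfig

# Passing a host together with is_online=True makes update_config derive
# is_golden=False, i.e. this process acts as the NPU-side client.
msCheckerConfig.update_config(is_online=True, host="192.168.0.2", port=30001, rank_list=[0, 1])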
diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dispatch.py b/debug/accuracy_tools/api_accuracy_checker/dump/dispatch.py
new file mode 100644
index 000000000..609e7e59c
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/dump/dispatch.py
@@ -0,0 +1,85 @@
+import os
+from functools import wraps
+
+import yaml
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+from api_accuracy_checker.tensor_transport_layer.attl import ApiData
+from api_accuracy_checker.common.utils import get_tensor_rank
+from api_accuracy_checker.common.config import msCheckerConfig
+from api_accuracy_checker.dump.dump import DumpUtil
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
+
+
+def singleton(cls):
+    _instance = {}
+
+    def inner():
+        if cls not in _instance:
+            _instance[cls] = cls()
+        return _instance[cls]
+    return inner
+
+
+@singleton
+class Counter:
+    def __init__(self) -> None:
+        self.index_dict = {}
+
+
+counter = Counter()
+yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
+with FileOpen(yaml_path, 'r') as f:
+    yaml_file = yaml.safe_load(f)
+
+
+class AccuracyCheckerDispatch(TorchDispatchMode):
+    def __init__(self):
+        super(AccuracyCheckerDispatch, self).__init__()
+        self.attl = DumpUtil.attl
+        self.counter = counter
+        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
+        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
+
+    def enable_autogard(self, aten_api):
+        if aten_api in self.npu_adjust_autogard:
+            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
+
+    def __torch_dispatch__(self, func, types, args=None, kwargs=None):
+        kwargs = kwargs or {}  # kwargs may arrive as None from the dispatcher
+        func_name_split_list = func.__name__.split(".")
+        aten_api = func_name_split_list[0]
+        self.enable_autogard(aten_api)
+        if aten_api in self.aten_ops_blacklist:
+            npu_out = func(*args, **kwargs)
+            return npu_out
+
+        res = func(*args, **kwargs)
+        cur_rank = get_tensor_rank(args, res)
+        if cur_rank not in DumpUtil.rank_list:
+            return res
+        cur_api_number = self.counter.index_dict.setdefault(aten_api, 0)
+        api_name = f'Aten*{aten_api}*{cur_api_number}'
+        api_data = ApiData(api_name, args, kwargs, res, DumpUtil.call_num, cur_rank)
+        self.attl.send(api_data)
+        self.counter.index_dict[aten_api] += 1
+
+        return res
+
+
+def dispatch4data(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not DumpUtil.get_dump_switch() or DumpUtil.phase not in ("backward", "all") or \
+                not msCheckerConfig.is_online:
+            return func(*args, **kwargs)
+        DumpUtil.set_dump_switch("OFF")
+        with AccuracyCheckerDispatch():
+            res = func(*args, **kwargs)
+        DumpUtil.set_dump_switch("ON")
+        return res
+
+    return wrapper
+
+
+torch.autograd.backward = dispatch4data(torch.autograd.backward)
diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
index d8b317aa2..535b912dd 100644
--- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
+++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
@@ -14,13 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+import os
+import time
+
+import torch.distributed as dist
 
 from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo
 from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json
-from api_accuracy_checker.common.utils import print_error_log, CompareException, print_info_log
+from api_accuracy_checker.common.utils import print_error_log, CompareException, print_info_log, \
+    get_tensor_rank, logger
 from api_accuracy_checker.hook_module.register_hook import initialize_hook
 from api_accuracy_checker.common.config import msCheckerConfig
 
+if msCheckerConfig.is_online:
+    from api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData
+
 
 def set_dump_switch(switch):
     if switch not in ["ON", "OFF"]:
@@ -35,8 +43,8 @@
 def check_dataloader_status():
     if msCheckerConfig.enable_dataloader:
         error_info = ("If you want to use this function, set enable_dataloader "
-                      "in the accuracy_tools/api_accuracy_check/config.yaml "
-                      "to False first")
+                      "in the accuracy_tools/api_accuracy_check/config.yaml "
+                      "to False first")
         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
@@ -59,6 +67,15 @@ def step():
 class DumpUtil(object):
     dump_switch = None
     call_num = 0
+    phase = "all"
+    rank_list = msCheckerConfig.rank_list
+    attl_config = None
+    attl = None
+    if msCheckerConfig.is_online:
+        attl_config = ATTLConfig(msCheckerConfig.is_golden, connect_ip=msCheckerConfig.host,
+                                 connect_port=msCheckerConfig.port)
+        need_dump = dist.get_rank() in msCheckerConfig.rank_list if dist.is_initialized() else True
+        attl = ATTL('gpu' if msCheckerConfig.is_golden else 'npu', attl_config, need_dump=need_dump)
 
     @staticmethod
     def set_dump_switch(switch):
@@ -73,6 +90,14 @@
     if DumpUtil.call_num in msCheckerConfig.target_iter:
         set_dump_switch("ON")
     elif DumpUtil.call_num > max(msCheckerConfig.target_iter):
+        if msCheckerConfig.is_online:
+            if DumpUtil.attl.socket_manager is not None:
+                logger.debug(f"process {os.getpid()} has finished, sending STOP signal")
+                DumpUtil.attl.socket_manager.send_stop_signal()
+                logger.debug(f"stopped rank_{dist.get_rank()} process")
+            else:
+                while True:
+                    time.sleep(2)
         raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1))
     else:
         set_dump_switch("OFF")
@@ -99,11 +124,27 @@ def pretest_info_dump(name, out_feat, module, phase):
     write_api_info_json(api_info)
 
 
+def pretest_real_data_transport(name, out_feat, module, phase):
+    if not DumpUtil.get_dump_switch():
+        return
+    if phase == DumpConst.forward and (DumpUtil.phase == "all" or DumpUtil.phase == phase):
+        cur_rank = get_tensor_rank(module.input_args, out_feat)
+        if cur_rank not in DumpUtil.rank_list:
+            return
+        api_data = 
ApiData(name, module.input_args, module.input_kwargs, out_feat, DumpUtil.call_num, cur_rank) + print_info_log(f"tools is dumping api: {api_data.name}, rank: {cur_rank}") + DumpUtil.attl.send(api_data) + + def pretest_hook(name, phase): def pretest_info_dump_hook(module, in_feat, out_feat): - pretest_info_dump(name, out_feat, module, phase) + if msCheckerConfig.is_online: + pretest_real_data_transport(name, out_feat, module, phase) + else: + pretest_info_dump(name, out_feat, module, phase) if hasattr(module, "input_args"): del module.input_args if hasattr(module, "input_kwargs"): del module.input_kwargs + return pretest_info_dump_hook diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/torch_ops_config.yaml b/debug/accuracy_tools/api_accuracy_checker/dump/torch_ops_config.yaml new file mode 100644 index 000000000..089aa84ac --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/dump/torch_ops_config.yaml @@ -0,0 +1,60 @@ +aten_ops_blacklist: + - _cudnn_rnn + - _local_scalar_dense + - _pin_memory + - _to_copy + - _unsafe_view + - clone + - contiguous + - copy_ + - cudnn_batch_norm + - cudnn_batch_norm_backward + - detach + - empty + - index_put_ + - lift_fresh + - max_pool2d_with_indices_backward # shape unmatch + - native_batch_norm_backward + - new_empty + - new_empty_strided + - new_full + - new_ones + - new_zeros + - ones + - ones_like + - permute + - rand + - rand_like + - randint + - randint_like + - randn + - randn_like + - randperm + - scalar_tensor + - select + - to + - transpose + - unbind + - view + - zero + - zero_ + - zeros + - zeros_like + - _record_function_enter_new + - _record_function_exit + - broadcast_ + - allreduce_ + - npu_clear_float_status + - npu_format_cast + - npu_dtype_cast + - _allgather_base_ + - _reduce_scatter_base_ + - is_same_size + +npu_adjust_autogard: + - adaptive_avg_pool2d + - batch_norm + - log_softmax + - nll_loss + - to + \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py index b355e029b..076366e93 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py @@ -18,6 +18,14 @@ import torch from api_accuracy_checker.hook_module import wrap_torch, wrap_functional, wrap_tensor +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + from api_accuracy_checker.hook_module import wrap_npu_custom + def initialize_hook(hook): wrap_tensor.wrap_tensor_ops_and_bind(hook) @@ -35,3 +43,9 @@ def initialize_hook(hook): if attr_name.startswith("wrap_"): setattr(torch.nn.functional, attr_name[5:], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) + if not is_gpu: + wrap_npu_custom.wrap_npu_ops_and_bind(hook) + for attr_name in dir(wrap_npu_custom.HOOKNpuOP): + if attr_name.startswith("wrap_"): + setattr(torch_npu, attr_name[5:], getattr(wrap_npu_custom.HOOKNpuOP, attr_name)) + diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml index c7ed0a1f8..50293ff45 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml @@ -997,3 +997,746 @@ torch: - vstack - where - xlogy_ + +aten: + - signbit + - logical_not_ + - _foreach_copy_ + - clamp + - hardswish_ + - arcsin_ + - 
logsumexp + - native_group_norm + - special_i1e + - bitwise_and + - new_full + - fft_ihfft + - _adaptive_avg_pool2d + - scatter_add + - abs + - selu + - exponential + - silu + - _native_batch_norm_legit_functional + - special_hermite_polynomial_h + - tanh_ + - log_sigmoid_forward + - _fft_c2c + - heaviside_ + - sigmoid_backward + - zeros_like + - as_strided_scatter + - trace + - _assert_async + - avg_pool2d_backward + - exp2 + - binary_cross_entropy_backward + - geometric + - fft_ihfftn + - smooth_l1_loss + - multiply + - __lshift__ + - binary_cross_entropy_with_logits + - _embedding_bag + - arange + - linalg_qr + - _embedding_bag_forward_only + - _unsafe_view + - remainder + - cholesky_inverse + - sub_ + - zero + - fix + - xlogy + - __doc__ + - rsqrt_ + - cummin + - __xor__ + - eye + - _fused_adam + - ceil + - nll_loss2d_backward + - replication_pad3d_backward + - fill_ + - logaddexp2 + - _thnn_fused_lstm_cell_backward_impl + - native_dropout + - fft_ifft + - expand + - _cdist_backward + - avg_pool3d_backward + - round_ + - topk + - max_unpool3d + - xlogy_ + - reflection_pad2d_backward + - addcdiv_ + - relu6 + - multilabel_margin_loss_forward + - prelu + - logaddexp + - _cholesky_solve_helper + - _foreach_addcdiv + - arctan_ + - fft_irfftn + - logical_or + - bitwise_or_ + - hardtanh_backward + - uniform + - less_equal + - _foreach_sub + - linalg_cholesky_ex + - hardswish + - fft_fft2 + - sign + - min + - norm + - asin + - addcmul_ + - stft + - col2im + - special_chebyshev_polynomial_u + - adaptive_max_pool3d + - __ilshift__ + - _resize_output + - gather + - lu_unpack + - native_batch_norm_backward + - sigmoid + - sqrt + - new_empty_strided + - _foreach_lerp_ + - mean + - scatter_add_ + - _fft_c2r + - rand_like + - true_divide_ + - gcd_ + - multinomial + - permute + - index_put_ + - arcsinh_ + - log1p_ + - index_add + - atan + - glu_backward + - searchsorted + - fill + - _unsafe_index + - index_reduce_ + - replication_pad2d + - expm1_ + - hardsigmoid + - addmm + - fft_fftn + - fft_ifftshift + - special_modified_bessel_k1 + - fft_rfft + - ge + - _adaptive_avg_pool2d_backward + - argmin + - linalg_lu_factor_ex + - atanh_ + - addmv + - _foreach_sqrt_ + - huber_loss_backward + - empty_like + - softshrink + - subtract_ + - bitwise_left_shift_ + - special_modified_bessel_i0 + - _nested_tensor_from_tensor_list + - slice_backward + - special_modified_bessel_i1 + - special_chebyshev_polynomial_t + - conj_physical + - _cdist_forward + - margin_ranking_loss + - max_pool3d_with_indices_backward + - _foreach_reciprocal_ + - lcm + - transpose_ + - cudnn_batch_norm_backward + - reciprocal + - copysign_ + - _foreach_pow + - rad2deg + - _foreach_sqrt + - negative + - replication_pad3d + - atanh + - _linalg_eigh + - igamma_ + - special_i0e + - linalg_ldl_factor_ex + - special_ndtri + - logit + - diagonal_copy + - triu + - silu_ + - polygamma + - square_ + - nextafter_ + - special_scaled_modified_bessel_k0 + - bitwise_not + - var + - mkldnn_rnn_layer_backward + - upsample_bilinear2d + - arctan2 + - clone + - arcsin + - new_ones + - soft_margin_loss + - nan_to_num + - huber_loss + - linalg_lu_solve + - elu_backward + - acosh + - __ior__ + - _unsafe_index_put + - __or__ + - _linalg_slogdet + - arcsinh + - select_scatter + - less_ + - reflection_pad1d + - istft + - reflection_pad2d + - diagonal_backward + - special_entr + - _softmax_backward_data + - randn + - celu + - embedding + - igammac_ + - new_zeros + - native_layer_norm_backward + - nonzero_static + - diagonal_scatter + - grid_sampler_2d + - 
smooth_l1_loss_backward + - _to_copy + - fft_irfft2 + - relu_ + - fmod + - log1p + - i0 + - mse_loss_backward + - copy + - special_laguerre_polynomial_l + - addmv_ + - quantized_gru + - diag_embed + - acos + - fmod_ + - linalg_cross + - mvlgamma_ + - _foreach_mul + - cummax + - less_equal_ + - ne + - to + - _pdist_forward + - special_xlog1py + - digamma + - lgamma + - mv + - softplus + - special_bessel_y1 + - pin_memory + - logical_xor_ + - cat + - grid_sampler_2d_backward + - frac_ + - dropout + - unsafe_chunk + - masked_fill_ + - log + - negative_ + - _scaled_dot_product_flash_attention + - _amp_foreach_non_finite_check_and_unscale_ + - randn_like + - add + - roll + - threshold + - gcd + - asinh + - round + - t_ + - unfold_backward + - scatter_reduce + - softplus_backward + - bitwise_right_shift_ + - pdist + - select_backward + - relu + - special_bessel_j1 + - asinh_ + - pow + - fft_fftshift + - clamp_max_ + - logical_xor + - index_reduce + - _foreach_add_ + - adaptive_max_pool2d + - adaptive_max_pool3d_backward + - tan + - addbmm_ + - cosh_ + - __rshift__ + - _foreach_maximum + - fft_ifftn + - special_spherical_bessel_j0 + - split_with_sizes + - divide_ + - neg_ + - nll_loss + - _euclidean_dist + - pairwise_distance + - _adaptive_avg_pool3d + - slice + - absolute_ + - gelu_backward + - arccos + - sin + - tril_ + - triu_ + - fft_irfft + - flip + - _foreach_sign + - linalg_householder_product + - _list_to_tensor + - cumprod + - randint_like + - item + - narrow_copy + - tanh + - linalg_vector_norm + - _cudnn_rnn + - _scaled_dot_product_efficient_attention + - _reshape_alias + - _linalg_det + - constant_pad_nd + - _linalg_svd + - sinh_ + - view + - nll_loss_backward + - greater + - sqrt_ + - avg_pool3d + - arctan + - le_ + - _pdist_backward + - _adaptive_avg_pool3d_backward + - log_ + - logical_or_ + - mse_loss + - rrelu_with_noise_backward + - _native_batch_norm_legit + - log10 + - scatter_ + - atan2_ + - greater_equal + - index_select + - __iand__ + - digamma_ + - eq + - divide + - cholesky_solve + - _prelu_kernel + - fft_ifft2 + - _foreach_neg_ + - alias + - erfc_ + - not_equal + - mul + - gru + - _dir + - glu + - clip + - lt + - rsqrt + - avg_pool2d + - conj_physical_ + - quantized_lstm + - erfinv_ + - log10_ + - float_power_ + - _functional_assert_async + - hardtanh + - logical_and_ + - _resize_output_ + - clamp_min + - _functional_sym_constrain_range_for_size + - _addmm_activation + - bucketize + - _thnn_fused_lstm_cell + - zeros + - reflection_pad1d_backward + - tan_ + - bitwise_not_ + - addmm_ + - absolute + - as_strided + - special_ndtr + - gt_ + - baddbmm + - special_log_ndtr + - hardshrink + - fft_hfft + - hypot + - native_layer_norm + - _scaled_dot_product_flash_attention_backward + - floor_divide + - is_same_size + - std + - floor_divide_ + - clamp_min_ + - _foreach_sign_ + - std_mean + - tanh_backward + - _foreach_addcmul + - binary_cross_entropy + - threshold_backward + - deg2rad_ + - masked_fill + - linspace + - reflection_pad3d + - mish + - index_copy + - scatter_reduce_ + - _sparse_coo_tensor_with_dims_and_tensors + - __loader__ + - _foreach_div_ + - cosh + - _foreach_maximum_ + - neg + - lift_fresh + - logspace + - selu_ + - leaky_relu_ + - matmul + - _foreach_sub_ + - bitwise_or + - unfold + - fmin + - convolution + - argmax + - maximum + - reflection_pad3d_backward + - fft_fft + - mode + - remainder_ + - _foreach_neg + - erf_ + - special_zeta + - index_add_ + - arccos_ + - lgamma_ + - unsqueeze_ + - gelu_ + - bmm + - _add_relu + - unfold_copy + - not_equal_ + - subtract 
+ - true_divide + - max_pool2d_with_indices_backward + - _native_batch_norm_legit_no_training + - replication_pad1d + - name + - greater_ + - log_normal + - minimum + - alpha_dropout + - rnn_tanh + - _functional_sym_constrain_range + - sum + - _prelu_kernel_backward + - cumsum_ + - ne_ + - _linalg_solve_ex + - native_batch_norm + - igammac + - hypot_ + - exp + - leaky_relu + - new_empty + - cudnn_batch_norm + - resize_as_ + - mm + - triangular_solve + - sign_ + - clamp_max + - bitwise_right_shift + - logical_and + - special_i0 + - index_copy_ + - arctanh_ + - elu + - index + - isposinf + - linalg_solve_triangular + - logcumsumexp + - arccosh + - nan_to_num_ + - nll_loss_forward + - convolution_backward + - sub + - special_scaled_modified_bessel_k1 + - mish_ + - diagonal + - median + - tril + - sgn + - native_group_norm_backward + - stack + - take + - linalg_lu + - log2 + - hardsigmoid_ + - erfc + - max + - native_dropout_backward + - logit_ + - addr + - clip_ + - _foreach_minimum_ + - atan_ + - repeat + - cumprod_ + - bitwise_xor_ + - less + - index_put + - rrelu_with_noise + - addbmm + - special_bessel_y0 + - __and__ + - bernoulli_ + - uniform_ + - log2_ + - mul_ + - adaptive_max_pool2d_backward + - _foreach_addcmul_ + - slice_scatter + - isneginf + - pow_ + - renorm_ + - arccosh_ + - replication_pad1d_backward + - bitwise_and_ + - heaviside + - renorm + - special_modified_bessel_k0 + - le + - is_pinned + - __ixor__ + - leaky_relu_backward + - count_nonzero + - _fused_adam_ + - repeat_interleave + - upsample_bicubic2d + - rsub + - arctan2_ + - frac + - scalar_tensor + - rrelu_with_noise_ + - rot90 + - erf + - lerp_ + - expm1 + - full + - sym_constrain_range_for_size + - prod + - normal_ + - elu_ + - special_airy_ai + - nextafter + - split + - addcdiv + - fft_rfft2 + - max_pool3d_with_indices + - positive + - transpose + - mish_backward + - clamp_ + - exp_ + - _foreach_reciprocal + - linalg_matrix_exp + - unsqueeze + - upsample_nearest2d + - sinc_ + - select + - rad2deg_ + - trunc_ + - _make_dep_token + - nanmedian + - fft_hfftn + - hardtanh_ + - sym_constrain_range + - index_fill_ + - deg2rad + - rand + - sinc + - pixel_shuffle + - tril_indices + - copy_ + - _int_mm + - greater_equal_ + - celu_ + - div + - igamma + - exp2_ + - cos + - log_normal_ + - _log_softmax_backward_data + - im2col + - reciprocal_ + - amax + - broadcast_tensors + - erfinv + - __spec__ + - _fused_dropout + - special_hermite_polynomial_he + - aminmax + - rnn_relu + - meshgrid + - var_mean + - eq_ + - upsample_nearest3d + - dot + - zero_ + - floor_ + - fft_rfftn + - special_erfcx + - _foreach_div + - fft_hfft2 + - _upsample_bilinear2d_aa + - sort + - log_sigmoid_backward + - add_ + - copysign + - bernoulli + - special_bessel_j0 + - max_pool2d_with_indices + - _scaled_dot_product_efficient_attention_backward + - t + - _softmax + - arctanh + - hinge_embedding_loss + - hardswish_backward + - fmax + - multiply_ + - floor + - lstm + - i0_ + - cholesky + - where + - __irshift__ + - addcmul + - embedding_dense_backward + - sigmoid_ + - fix_ + - ormqr + - exponential_ + - __name__ + - fft_ihfft2 + - logical_not + - ones + - sgn_ + - sinh + - any + - _foreach_addcdiv_ + - asin_ + - gt + - lift + - squeeze + - grid_sampler_3d_backward + - atan2 + - _fft_r2c + - angle + - silu_backward + - acosh_ + - abs_ + - lerp + - special_i1 + - complex + - ceil_ + - _foreach_minimum + - hardsigmoid_backward + - upsample_nearest1d + - mvlgamma + - acos_ + - lt_ + - grid_sampler_3d + - max_unpool2d + - ones_like + - soft_margin_loss_backward + 
- _fused_moving_avg_obs_fq_helper + - isnan + - nansum + - baddbmm_ + - amin + - isinf + - bitwise_left_shift + - unsafe_split_with_sizes + - full_like + - sin_ + - bitwise_xor + - linalg_ldl_solve + - cos_ + - div_ + - polar + - randint + - trunc + - __package__ + - nll_loss2d_forward + - diag + - argsort + - _foreach_mul_ + - square + - detach + - affine_grid_generator + - _pin_memory + - geometric_ + - unbind + - randperm + - upsample_nearest2d_backward + - all + - threshold_ + - unsafe_split + - cauchy + - normal + - linalg_inv_ex + - multi_margin_loss + - cumsum + - gelu + - index_fill + - scatter + - mkldnn_rnn_layer + - ge_ + - dist + - _foreach_add + - logit_backward + - triu_indices + - lcm_ + - empty_strided + - replication_pad2d_backward + - cauchy_ + - _log_softmax + - vdot + +white_aten_ops: + - embedding_backward + +torch_npu: + - npu_apply_adam_w + - npu_confusion_transpose + - fast_gelu + - npu_layer_norm_eval + - npu_linear + - npu_rms_norm + - npu_rotary_mul + - npu_scaled_masked_softmax + - npu_swiglu + - npu_fusion_attention diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py index 7d16ac993..6d06c9f19 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py @@ -26,4 +26,7 @@ with FileOpen(yaml_path, 'r') as f: Ops = yaml.safe_load(f) WrapFunctionalOps = Ops.get('functional') WrapTensorOps = Ops.get('tensor') - WrapTorchOps = Ops.get('torch') \ No newline at end of file + WrapTorchOps = Ops.get('torch') + WrapAtenOps = Ops.get('aten') + WrapNPUOps = Ops.get('torch_npu') + WhiteAtenOps = Ops.get('white_aten_ops') diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py new file mode 100644 index 000000000..3700ae8a4 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import torch + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard +from api_accuracy_checker.hook_module.utils import WrapAtenOps, WhiteAtenOps +from api_accuracy_checker.common.function_factory import npu_custom_grad_functions + + +aten_func = {} +for f in dir(torch.ops.aten): + aten_func[f] = getattr(torch.ops.aten, f) + + +def get_aten_ops(): + global WrapAtenOps + _all_aten_ops = dir(torch.ops.aten) + return set(WrapAtenOps) & set(_all_aten_ops) + + +class HOOKAtenOP(object): + pass + + +class AtenOPTemplate(HOOKModule): + def __init__(self, op): + self.op = op + + @torch_device_guard + def forward(self, *args, **kwargs): + if self.op in npu_custom_grad_functions: + return npu_custom_grad_functions[self.op](*args, **kwargs) + if self.op in WhiteAtenOps: + return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs) + if self.op not in aten_func: + raise Exception(f"The op {self.op} is not in dir(torch.ops.aten) and support yaml.") + return aten_func[self.op](*args, **kwargs) + + +class AtenOPPacketTemplate(): + def __init__(self, opPacket, hook): + self.opPacket = opPacket + self.hook = hook + + def __getattr__(self, key): + try: + attr = getattr(self.opPacket, key) + except AttributeError as e: + raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e + if isinstance(attr, torch._ops.OpOverload): + return AtenOPTemplate(attr, self.hook) + else: + return attr + + def overloads(self): + return self.opPacket.overloads() + + @torch_device_guard + def __call__(self, *args, **kwargs): + return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs) + + +def wrap_aten_op(op, hook): + return AtenOPPacketTemplate(op, hook) + + +def wrap_aten_ops_and_bind(hook): + _aten_ops = get_aten_ops() + for op_name in _aten_ops: + if not isinstance(aten_func.get(op_name), torch._ops.OpOverloadPacket): + continue + setattr(HOOKAtenOP, "wrap_" + str(op_name), wrap_aten_op(aten_func.get(op_name), hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py new file mode 100644 index 000000000..a4a42df11 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import torch + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard, torch_without_guard_version +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.hook_module.utils import WrapNPUOps +from api_accuracy_checker.common.function_factory import npu_custom_functions + +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + + +def get_npu_ops(): + global WrapNPUOps + if torch_without_guard_version: + _npu_ops = dir(torch.ops.npu) + else: + _npu_ops = dir(torch_npu._C._VariableFunctionsClass) + + if msCheckerConfig.white_list: + return set(WrapNPUOps) & set(_npu_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapNPUOps) & set(_npu_ops) + + +class HOOKNpuOP(object): + pass + + +class NPUOPTemplate(HOOKModule): + + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "NPU*" + str(op_name) + "*" + self.need_hook = need_hook + if need_hook: + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + if not self.need_hook: + if self.op_name_ not in npu_custom_functions: + raise Exception(f'There is not bench function {self.op_name_}') + return npu_custom_functions[self.op_name_](*args, **kwargs) + if torch_without_guard_version: + return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) + else: + return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) + + +def wrap_npu_op(op_name, hook): + + def npu_op_template(*args, **kwargs): + return NPUOPTemplate(op_name, hook)(*args, **kwargs) + + return npu_op_template + + +def wrap_npu_ops_and_bind(hook): + _npu_ops = get_npu_ops() + for op_name in _npu_ops: + setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook)) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 05bd4305a..e45c56732 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -26,6 +26,8 @@ from api_accuracy_checker.compare.compare_utils import CompareConst from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate +from api_accuracy_checker.hook_module.wrap_aten import AtenOPTemplate +from api_accuracy_checker.hook_module.wrap_npu_custom import NPUOPTemplate from api_accuracy_checker.common.config import msCheckerConfig from api_accuracy_checker.dump.api_info import APIInfo from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create @@ -34,10 +36,14 @@ from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_ from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, \ change_mode, check_file_suffix, check_link +if msCheckerConfig.is_online: + from api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData + from api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher + current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time -RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" -DETAILS_FILE_NAME = 
"accuracy_checking_details_" + current_time + ".csv" +RESULT_FILE_NAME = f"accuracy_checking_result_" + current_time + ".csv" +DETAILS_FILE_NAME = f"accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] @@ -70,6 +76,12 @@ def exec_api(api_type, api_name, args, kwargs): if api_type == "Torch": torch_api = TorchOPTemplate(api_name, str, False) out = torch_api.forward(*args, **kwargs) + if api_type == "Aten": + torch_api = AtenOPTemplate(api_name) + out = torch_api.forward(*args, **kwargs) + if api_type == "NPU": + torch_api = NPUOPTemplate(api_name, None, False) + out = torch_api.forward(*args, **kwargs) return out @@ -161,10 +173,23 @@ def run_ut(config): error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) print_info_log(f"UT task error_datas will be saved in {error_data_path}") compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) - with FileOpen(config.result_csv_path, 'r') as file: - csv_reader = csv.reader(file) - next(csv_reader) - api_name_set = {row[0] for row in csv_reader} + if msCheckerConfig.is_online: + run_api_online(config, compare) + else: + with FileOpen(config.result_csv_path, 'r') as file: + csv_reader = csv.reader(file) + next(csv_reader) + api_name_set = {row[0] for row in csv_reader} + run_api_offline(config, compare, api_name_set) + for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list): + change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + print_info_log(f"UT task result csv is saved in {result_csv_path}") + print_info_log(f"UT task details csv is saved in {details_csv_path}") + compare.print_pretest_result() + + +def run_api_offline(config, compare, api_name_set): for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue @@ -174,8 +199,7 @@ def run_ut(config): if api_name not in set(msCheckerConfig.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, - data_info) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) if config.save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: @@ -187,16 +211,40 @@ def run_ut(config): print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) err_column = CompareColumn() fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err)) - compare.record_results(api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None) + result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0) + compare.record_results(result_info) finally: if is_gpu: torch.cuda.empty_cache() else: torch.npu.empty_cache() gc.collect() - change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY) - compare.print_pretest_result() + + +def run_api_online(config, compare): + attl = init_attl() + dispatcher = 
+def run_api_online(config, compare):
+    attl = init_attl()
+    dispatcher = ConsumerDispatcher(compare=compare)
+    dispatcher.start(handle_func=run_torch_api_online, config=config)
+    while True:
+        api_data = attl.recv()
+        if api_data == 'STOP_':
+            continue
+        if api_data == 'KILL_':
+            time.sleep(1)
+            print_info_log("========== STOP signal received ==========")
+            dispatcher.stop()
+            attl.stop_serve()
+            time.sleep(1)
+            break
+        if not isinstance(api_data, ApiData):
+            continue
+        api_full_name = api_data.name
+
+        if msCheckerConfig.white_list:
+            [_, api_name, _] = api_full_name.split("*")
+            if api_name not in set(msCheckerConfig.white_list):
+                continue
+        dispatcher.update_consume_queue(api_data)


 def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
@@ -257,6 +305,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
     return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)


+def run_torch_api_online(api_full_name, api_data, backward_content):
+    in_fwd_data_list = []
+    [api_type, api_name, _] = api_full_name.split("*")
+    args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
+    in_fwd_data_list.append(args)
+    in_fwd_data_list.append(kwargs)
+    if kwargs.get("device"):
+        del kwargs["device"]
+
+    device_out = exec_api(api_type, api_name, args, kwargs)
+    device_out = device_out.cpu() if hasattr(device_out, "cpu") else device_out
+    return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
+
+
 def get_api_info(api_info_dict, api_name, real_data_path):
     convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
     need_grad = True
@@ -321,13 +383,18 @@ def get_validated_details_csv_path(validated_result_csv_path):
     return validated_details_csv_path


+def init_attl():
+    attl = ATTL('gpu', ATTLConfig(is_golden=True, connect_port=msCheckerConfig.port))
+    return attl
+
+
 def _run_ut_parser(parser):
     parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", type=str,
-                        help=" The api param tool forward result file: generate from api param tool, "
+                        help="The api param tool forward result file: generated by the api param tool, "
                              "a json file.",
-                        required=True)
+                        required=False)
     parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str,
-                        help=" The api param tool backward result file: generate from api param tool, "
+                        help="The api param tool backward result file: generated by the api param tool, "
                              "a json file.",
                         required=False)
     parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
@@ -412,16 +479,19 @@ def run_ut_command(args):
         except Exception as error:
             print_error_log(f"Set device id failed. 
device id is: {args.device_id}") raise NotImplementedError from error - check_link(args.forward_input_file) - forward_file = os.path.realpath(args.forward_input_file) - check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) + out_path = os.path.realpath(args.out_path) if args.out_path else "./" check_path_before_create(out_path) create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data - forward_content = get_json_contents(forward_file) + forward_content = {} + if args.forward_input_file: + check_link(args.forward_input_file) + forward_file = os.path.realpath(args.forward_input_file) + check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) + forward_content = get_json_contents(forward_file) if args.filter_api: forward_content = preprocess_forward_content(forward_content) backward_content = {} @@ -447,14 +517,16 @@ def run_ut_command(args): class UtDataInfo: - def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, backward_message): - self.bench_grad = bench_grad - self.device_grad = device_grad - self.device_output = device_output - self.bench_output = bench_output + def __init__(self, bench_grad_out, device_grad_out, device_out, + bench_out, grad_in, in_fwd_data_list, backward_message, rank=0): + self.bench_grad_out = bench_grad_out + self.device_grad_out = device_grad_out + self.device_out = device_out + self.bench_out = bench_out self.grad_in = grad_in self.in_fwd_data_list = in_fwd_data_list self.backward_message = backward_message + self.rank = rank class UtAPIInfo(APIInfo): diff --git a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py new file mode 100644 index 000000000..fcf93b58a --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py @@ -0,0 +1,150 @@ +import io +import time +from multiprocessing import Queue +from typing import Optional, Union, Dict, Any +from collections import namedtuple +from dataclasses import dataclass + +import torch + +from api_accuracy_checker.tensor_transport_layer.client import TCPClient +from api_accuracy_checker.tensor_transport_layer.server import TCPServer +from api_accuracy_checker.common.utils import logger + + +ApiData = namedtuple('ApiData', ['name', 'args', 'kwargs', 'result', 'step', 'rank'], + defaults=['unknown', None, None, None, 0, 0]) +BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] + + +@dataclass +class ATTLConfig: + # net_config: dict + is_golden: bool + connect_ip: str = "127.0.0.1" + connect_port: int = 8006 + # storage_config + check_sum: bool = True + queue_size: int = 50 + + +class ATTL: + def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None: + self.session_id = session_id + self.session_config = session_config + self.logger = logger + self.socket_manager = None + self.data_queue = Queue(maxsize=50) + self.dequeue_list = [] + self.message_end = False + self.kill_progress = False + if self.session_config.is_golden: + self.socket_manager = TCPServer(self.session_config.connect_port, + self.data_queue, + self.session_config.check_sum) + self.socket_manager.start() + elif need_dump: + self.socket_manager = TCPClient(self.session_config.connect_ip, + self.session_config.connect_port, + self.session_config.check_sum) + 
self.socket_manager.start()
+
+    def stop_serve(self):
+        if isinstance(self.socket_manager, TCPServer):
+            self.socket_manager.stop()
+
+    def client_handle(self, data, rank: int = 0, step: int = 0):
+        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
+
+    def send(self, buffer: BufferType):
+        """
+        Client ("npu") side entry point: serialize a buffer and enqueue it for sending.
+        """
+        # move tensors to cpu first so the receiver can load them without the device
+        if isinstance(buffer, ApiData):
+            buffer = move2target_device(buffer, torch.device('cpu'))
+
+            if 'device' in buffer.kwargs:
+                buffer.kwargs.pop('device')
+        rank = buffer.rank if hasattr(buffer, "rank") else 0
+        step = buffer.step if hasattr(buffer, "step") else 0
+        io_buff = io.BytesIO()
+        torch.save(buffer, io_buff)
+        self.client_handle(io_buff.getvalue(), rank=rank, step=step)
+
+    def recv(self, timeout_ms=0) -> Optional[BufferType]:
+        buffer = None
+        while buffer is None:
+            if timeout_ms > 0:
+                time.sleep(timeout_ms / 1000.0)
+            if buffer is None and not self.data_queue.empty():
+                buffer = self.data_queue.get()
+                break
+            if buffer is None and timeout_ms > 0:  # timeout is the only case we give up and return None
+                break
+            if self.message_end and self.data_queue.empty():
+                buffer = b"KILL_CONFIRM"
+                self.kill_progress = True
+                break
+            time.sleep(0.1)  # wait before the next attempt
+        if buffer is None:
+            # this is the result of a timeout
+            self.logger.info("Receiving API data timed out.")
+        else:
+            if buffer == b"STOP_":
+                return "STOP_"
+            if buffer == b"KILL_":
+                self.message_end = True
+                return "STOP_"
+            if buffer == b"KILL_CONFIRM":
+                self.kill_progress = True
+                return "KILL_"
+            buffer = io.BytesIO(buffer)
+            try:
+                buffer = torch.load(buffer, map_location="cpu")
+            except Exception as e:
+                self.logger.error("Failed to load the received buffer, please check it: %s", e)
+            if isinstance(buffer, bytes):
+                return None
+            if isinstance(buffer, str):
+                return buffer
+
+        return buffer
+
+
+def move2device_exec(obj, device):
+    if isinstance(obj, (tuple, list)):
+        data_list = [move2device_exec(val, device) for val in obj]
+        return data_list if isinstance(obj, list) else tuple(data_list)
+    if isinstance(obj, dict):
+        return {key: move2device_exec(val, device) for key, val in obj.items()}
+    elif isinstance(obj, torch.Tensor):
+        obj = obj.detach()
+        if obj.device.type != device:
+            obj = obj.to(device)
+        return obj
+    elif isinstance(obj, torch._C.device):
+        return torch.device(device)
+    else:
+        return obj
+
+
+def move2target_device(buffer: ApiData, target_device):
+    # handle args
+    new_args = move2device_exec(buffer.args, target_device)
+
+    # handle kwargs
+    new_kwargs = move2device_exec(buffer.kwargs, target_device)
+
+    # handle result
+    new_results = []
+    res = buffer.result[0] if isinstance(buffer.result, (tuple, list)) else buffer.result
+    if isinstance(res, torch.Tensor) and res.device.type != target_device:
+        new_results.append(res.detach().to(target_device))
+    else:
+        new_results.append(res)
+
+    if target_device == torch.device('cpu') or target_device == "cpu":
+        return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results[0], buffer.step, buffer.rank)
+    else:
+        return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)
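The transport above ships whole ApiData records as raw bytes: ATTL.send serializes with torch.save into an in-memory buffer, and ATTL.recv reverses it with torch.load. A minimal sketch of that round trip (payload contents illustrative):

import io
import torch

payload = {"name": "Functional*conv2d*0", "args": (torch.randn(2, 2),)}
buf = io.BytesIO()
torch.save(payload, buf)
raw = buf.getvalue()  # the bytes handed to the TCP layer

restored = torch.load(io.BytesIO(raw), map_location="cpu")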
diff --git a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/client.py
new file mode 100644
index 000000000..da73657dc
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/client.py
@@ -0,0 +1,309 @@
+import hashlib
+import io
+import struct
+import time
+import os
+from queue import Queue
+from threading import Thread
+from typing import Union
+from twisted.internet import reactor, protocol, endpoints
+from twisted.protocols.basic import FileSender
+from api_accuracy_checker.common.utils import logger, print_info_log
+
+
+class TCPDataItem:
+    def __init__(self, data,
+                 sequence_number: int,
+                 rank: int = 0,
+                 step: int = 0):
+        self.raw_data = data
+        self.sequence_number = sequence_number
+        self.rank = rank
+        self.step = step
+        self.retry_times = 0
+        self.pending_time = 0
+        self.busy_time = 0
+
+
+class TCPClient:
+    MAX_SENDING_QUEUE_SIZE = 20
+    ACK_SUCCESS = b"OK___"
+    ACK_ERROR = b"ERROR"
+    ACK_BUSY = b"BUSY_"
+    ACK_STOP = b"STOP_"
+    ACK_STOP_CONFIRM = b"OVER_"
+    ACK_KILL_PROCESS = b"KILL_"
+
+    QUEUE_PENDING_TIME = 600  # stop the sending thread if the queue stays blocked for 10 minutes
+    RESEND_RETRY_TIMES = 2  # maximum number of retransmissions
+    RESEND_TIMER_TIME = 5  # timeout, in seconds, of the ACK receive timer
+    RESEND_PENDING_TIME = 60  # give up on a data item after it has been pending for 1 minute
+
+    def __init__(self, host="localhost", port=8000, check_sum=False):
+        self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
+        self.resend_dict = dict()
+        self.host = host
+        self.port = port
+        self.factory = None
+        self.sequence_number = 0
+        self.signal_exit = False
+        self.tcp_manager = ClientProtocol(ack_queue_size=100,
+                                          chunk_size=655360,
+                                          check_sum=check_sum)
+        self.send_thread = Thread(target=self._sending_queue_data)
+        self.send_thread.daemon = True
+        self.send_thread.start()
+        self.destroy_thread = Thread(target=self._destroy_queue_data)
+        self.destroy_thread.daemon = True
+        self.destroy_thread.start()
+
+    def _ready_to_exit(self):
+        return self.signal_exit or self.tcp_manager.signal_exit
+
+    def start(self):
+        def conn_callback(cur_protocol):
+            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
+                logger.debug(f"Connected successfully, current process {os.getpid()}")
+            else:
+                logger.debug(f"Connection failed, current process {os.getpid()}")
+                raise ConnectionError(f"Failed to connect to {self.host}.")
+
+        def cur_protocol():
+            return self.tcp_manager
+
+        self.factory = MessageClientFactory()
+        self.factory.protocol = cur_protocol
+
+        endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
+        d = endpoint.connect(self.factory)
+        d.addCallback(conn_callback)
+
+        reactor_thread = Thread(target=self.run_reactor, daemon=True)
+        reactor_thread.start()
+
+    def run_reactor(self):
+        reactor.run(installSignalHandlers=False)
+
+    def send_after_queue_empty(self, data):
+        while not self._ready_to_exit():
+            self.add_to_sending_queue(data)
+            time.sleep(2)
+
+    def check_client_alive(self):
+        return self.factory.num_connections > 0
+
+    def stop(self):
+        self.tcp_manager.connection_timeout()
+
+    def send_stop_signal(self):
+        self.send_after_queue_empty(self.ACK_STOP)
+        while not self._ready_to_exit():
+            if not self.check_client_alive():
+                break
+            time.sleep(1)
+        while not self.tcp_manager.kill_process:
+            time.sleep(1)
+
+    def add_to_sending_queue(self, data: Union[bytes, TCPDataItem],
+                             rank: int = 0, step: int = 0):
+        if self._ready_to_exit():
+            return
+
+        send_data = data
+        if not isinstance(data, TCPDataItem):
+            send_data = TCPDataItem(data=data,
+                                    sequence_number=self.sequence_number,
+                                    rank=rank,
+                                    step=step)
+            self.sequence_number += 1
+
+        self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
+
+    def _send_data(self, data: TCPDataItem):
+        self.tcp_manager.send_wrapped_data(data.raw_data,
+                                           sequence_number=data.sequence_number,
+                                           rank=data.rank,
+                                           step=data.step
+                                           )
+
+    @staticmethod
+    def get_obj_key(data: TCPDataItem):
+        return str(data.sequence_number) + "_" + str(data.rank) + "_" + str(data.step)
+
+    def _sending_queue_data(self):
+        while True:
+            if not self.tcp_manager.is_connected:
+                time.sleep(0.1)
+                continue
+
+            while self.send_queue.qsize() > 0:
+                if self._ready_to_exit():
+                    break
+                if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
+                    data_obj = self.send_queue.get()
+                    self._send_data(data_obj)
+                    resend_key = self.get_obj_key(data_obj)
+                    if resend_key not in self.resend_dict.keys():
+                        # first transmission of this data item
+                        self.resend_dict[resend_key] = data_obj
+                else:
+                    time.sleep(0.1)
+
+            if self._ready_to_exit():
+                logger.debug("Successfully closed the sending thread.")
+                break
+            time.sleep(0.1)
+
+    def _destroy_queue_data(self):
+        while True:
+            if self._ready_to_exit():
+                break
+
+            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
+                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
+                obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step)
+                current_item = self.resend_dict.get(obj_key)
+
+                if current_item is None:
+                    continue
+
+                if ack_info == self.ACK_SUCCESS:
+                    self.resend_dict.pop(obj_key)
+                elif ack_info == self.ACK_BUSY:
+                    logger.debug("RECV BUSY ACK")
+                    if current_item.busy_time > 5:
+                        self._resend_data(current_item)
+                    else:
+                        current_item.busy_time += 1
+                elif ack_info == self.ACK_ERROR:
+                    logger.debug("RECV ERROR ACK")
+                    self._resend_data(current_item)
+                elif ack_info == self.ACK_STOP_CONFIRM:
+                    logger.debug("RECV STOP ACK")
+                    self.factory.num_connections -= 1
+
+                break
+
+            time.sleep(0.1)
+
+    def _resend_data(self, data: TCPDataItem):
+        if data.retry_times < self.RESEND_RETRY_TIMES:
+            data.retry_times += 1
+            logger.debug(f"Resend data seq number: {data.sequence_number}")
+            self.add_to_sending_queue(data)
+        else:
+            self.resend_dict.pop(self.get_obj_key(data), None)
+            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")
+
+    def _pending_data(self, data: TCPDataItem):
+        if data.pending_time >= self.RESEND_PENDING_TIME:
+            self.resend_dict.pop(self.get_obj_key(data), None)
+            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
+            return
+
+        pending_time = self._get_pending_time(data)
+        data.pending_time += pending_time
+        time.sleep(pending_time)
+
+    @staticmethod
+    def _get_pending_time(data: TCPDataItem) -> int:
+        # assume a throughput of roughly 50 MB per second
+        return max(1, len(data.raw_data) // (2 ** 20 * 50))
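For orientation, the 29-byte ACK frames parsed by ClientProtocol.dataReceived below have a fixed layout: a 5-byte ASCII status code followed by three unsigned 64-bit big-endian integers (sequence number, rank, step). A standalone sketch with illustrative values:

import struct

frame = b"OK___" + struct.pack("!QQQ", 7, 0, 1)

ack = frame[:5]                                      # b"OK___"
seq, rank, step = struct.unpack("!QQQ", frame[5:29]) # 7, 0, 1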
+
+
+class ClientProtocol(protocol.Protocol):
+    TIMEOUT = 60 * 10
+
+    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
+        self.buffer = io.BytesIO()
+        self.is_connected = False
+        self.check_sum = check_sum
+        self.tell = 0
+        self.ack_queue = Queue(maxsize=ack_queue_size)
+        self.file_sender = FileSender()
+        self.file_sender.CHUNK_SIZE = chunk_size
+        self.signal_exit = False
+        self.defer = None
+        self.kill_process = False
+
+    def dataReceived(self, data):
+        if self.timeout_call.active():
+            self.timeout_call.reset(self.TIMEOUT)
+
+        self.buffer.seek(0, 2)
+        self.buffer.write(data)
+        self.buffer.seek(self.tell)
+        while True:
+            if len(self.buffer.getvalue()) >= 29:  # 5 + 8 * 3
+                ack = self.buffer.read(5)
+                seq_number = struct.unpack('!Q', self.buffer.read(8))[0]
+                rank = struct.unpack('!Q', self.buffer.read(8))[0]
+                step = struct.unpack('!Q', self.buffer.read(8))[0]
+                if ack == b"KILL_":
+                    self.kill_process = True
+                    logger.debug(f"KILL signal received, PID {os.getpid()}")
+                if ack == b"OVER_":
+                    self.factory.num_connections -= 1
+                self.tell += 29
+                if not self.ack_queue.full():
+                    self.ack_queue.put((ack, seq_number, rank, step))
+                    self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
+                    self.tell = 0
+                else:
+                    time.sleep(0.1)
+            else:
+                break
+    def wrap_data(self, data):
+        length = len(data)
+        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
+        return length.to_bytes(8, byteorder='big'), md5_hash.encode()
+
+    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
+        length, md5_hash = self.wrap_data(data)
+        while True:
+            if self.defer is None or self.defer.called:
+                self.defer = self.send_large_data(length +
+                                                  sequence_number.to_bytes(8, byteorder='big') +
+                                                  rank.to_bytes(8, byteorder='big') +
+                                                  step.to_bytes(8, byteorder='big') +
+                                                  md5_hash +
+                                                  data)
+                break
+            else:
+                time.sleep(0.01)
+
+    def send_large_data(self, data):
+        d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
+        return d
+
+    def connection_timeout(self):
+        if self.factory.num_connections <= 0:
+            return
+
+        self.factory.num_connections -= 1
+        logger.debug(f"Connection timed out, closing {self.transport.addr}, PID {os.getpid()}")
+        self.transport.loseConnection()
+
+    def connectionMade(self):
+        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
+        self.is_connected = True
+        self.factory.num_connections += 1
+        print_info_log("Successfully connected to the server.")
+
+    def connectionLost(self, reason):
+        self.signal_exit = True
+        self.factory.num_connections -= 1
+        print_info_log("Lost connection with the server.")
+
+
+class MessageClientFactory(protocol.ClientFactory):
+    def __init__(self):
+        self.num_connections = 0
+
+    def clientConnectionFailed(self, connector, reason):
+        print_info_log(f"Failed to connect to the server: {reason.getErrorMessage()}")
+        reactor.stop()
+
+    def clientConnectionLost(self, connector, reason):
+        print_info_log(f"Client lost connection with the server: {reason.getErrorMessage()}")
+        reactor.stop()
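The data frames built by send_wrapped_data above, and parsed by ServerProtocol.dataReceived later in this patch, use a matching layout: four 8-byte big-endian integers (payload length, sequence number, rank, step), an optional 32-byte md5 hex digest when check_sum is enabled, then the payload. A standalone sketch with illustrative values:

import hashlib

payload = b"example-bytes"
header = (len(payload).to_bytes(8, "big")
          + (42).to_bytes(8, "big")    # sequence number
          + (0).to_bytes(8, "big")     # rank
          + (0).to_bytes(8, "big"))    # step
md5 = hashlib.md5(payload).hexdigest().encode()  # present only when check_sum is on
frame = header + md5 + payload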
diff --git a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
new file mode 100644
index 000000000..4a46aaa1a
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
@@ -0,0 +1,104 @@
+import time
+
+import torch
+import torch.multiprocessing as mp
+
+from api_accuracy_checker.tensor_transport_layer.attl import move2target_device
+from api_accuracy_checker.common.utils import print_error_log, print_warn_log, \
+    print_info_log, logger
+
+
+def run_ut_process(xpu_id, compare, consumer_queue, func, config):
+    device = torch.device(f'cuda:{xpu_id}')
+
+    while True:
+        if consumer_queue.empty():
+            time.sleep(0.1)
+            continue
+
+        api_data = consumer_queue.get()
+        if api_data == "KILL_":
+            return
+
+        api_full_name = api_data.name
+        api_data = move2target_device(api_data, device)
+        try:
+            data_info = func(api_full_name, api_data, config.backward_content)
+            logger.debug(f"Successfully executed {api_full_name} on the device")
+            is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
+            print_info_log(f"running api_full_name {api_full_name} ut, "
+                           f"is_fwd_success: {is_fwd_success}, "
+                           f"is_bwd_success: {is_bwd_success}")
+        except Exception as err:
+            [_, api_name, _] = api_full_name.split("*")
+            if "expected scalar type Long" in str(err):
+                print_warn_log(f"API {api_name} does not support int32 tensors on CPU; please add {api_name} to the "
+                               f"CONVERT_API 'int32_to_int64' list in accuracy_tools/api_accuracy_checker/common/utils.py.")
+            else:
+                print_error_log(f"Run {api_full_name} UT Error: {str(err)}")
+
+            compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err), api_data.rank))
+
+        finally:
+            torch.cuda.empty_cache()
+
+
+class ConsumerDispatcher:
+    def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None:
+        self.num_workers = num_workers
+        self.capacity = capacity
+        self.compare = compare
+        self.queues = []
+        self.reverse_sort = False
+        self.pool = None
+        self.device = device
+        self.data_id = 0
+        self.lock = mp.Lock()
+        self.result_queue = mp.Queue()
+        mp.set_start_method("spawn", force=True)
+
+    def start(self, handle_func, config):
+        self.processes = []
+        self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
+        for xpu_id, q in enumerate(self.queues):
+            p = mp.Process(name="run_ut_process", target=run_ut_process,
+                           args=(xpu_id, self.compare, q, handle_func, config))
+
+            p.start()
+            self.processes.append(p)
+        print_info_log("Successfully started the unittest processes.")
+
+    def update_consume_queue(self, api_data):
+        while True:
+            index = self._choose_max_empty_site_strategy()
+            if index != -1:
+                q = self.queues[index]
+                q.put(api_data)
+                logger.debug(f"Dispatching {api_data.name} to GPU {index}")
+                break
+            logger.debug("All UT queues are full, waiting")
+            time.sleep(0.1)
+
+    def _choose_max_empty_site_strategy(self):
+        maximum = 0
+        index = -1
+        # make full use of all devices instead of piling tasks onto the first cards
+        _reverse = 1 if not self.reverse_sort else -1
+        for i, q in enumerate(self.queues[::_reverse]):
+            empty_site = self.capacity - q.qsize()
+            if empty_site > maximum:
+                maximum = empty_site
+                index = i
+        index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index
+        self.reverse_sort = not self.reverse_sort
+        return index
+
+    def stop(self):
+        for q in self.queues:
+            while q.full():
+                time.sleep(0.1)
+            q.put("KILL_")
+
+        for p in self.processes:
+            p.join()
+        print_info_log("Successfully stopped the unittest processes.")
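_choose_max_empty_site_strategy above is a least-loaded pick across the worker queues; flipping the scan direction on every call keeps ties from always landing on the same card. The selection itself reduces to the following sketch (sizes illustrative):

capacity = 10
queue_sizes = [10, 7, 9]                               # qsize() per worker queue
free_slots = [capacity - size for size in queue_sizes]
index = max(range(len(free_slots)), key=free_slots.__getitem__)  # picks queue 1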
diff --git a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/server.py
new file mode 100644
index 000000000..29ceb8f93
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/server.py
@@ -0,0 +1,216 @@
+import struct
+import hashlib
+import time
+import io
+from threading import Thread
+from twisted.internet import reactor, protocol, endpoints
+from api_accuracy_checker.common.utils import logger, print_info_log
+
+
+class TCPServer:
+    def __init__(self, port, shared_queue, check_sum=False) -> None:
+        self.port = port
+        self.shared_queue = shared_queue
+        self.check_sum = check_sum
+        self.factory = MessageServerFactory()
+        self.reactor_thread = None
+
+    def start(self):
+        self.factory.protocol = self.build_protocol
+        endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
+        endpoint.listen(self.factory)
+        self.reactor_thread = Thread(target=self.run_reactor, daemon=True)
+        self.reactor_thread.start()
+
+    def is_running(self):
+        return not self.factory.is_all_connection_closed()
+
+    def stop(self):
+        self.factory.doStop()
+        reactor.callFromThread(reactor.sigInt, 2)
+        self.reactor_thread.join()
+
+    @staticmethod
+    def run_reactor():
+        reactor.run(installSignalHandlers=False)
+
+    def build_protocol(self):
+        return ServerProtocol(self.shared_queue, self.check_sum)
+
+
+class ServerProtocol(protocol.Protocol):
+    ACK_SUCCESS = b"OK___"
+    ACK_ERROR = b"ERROR"
+    ACK_BUSY = b"BUSY_"
+    ACK_STOP = b"STOP_"
+    ACK_STOP_CONFIRM = b"OVER_"
+    ACK_KILL_PROCESS = b"KILL_"
+
+    def __init__(self, shared_queue, check_sum=False):
+        self.buffer = io.BytesIO()
+        self.consumer_queue = shared_queue
+        self.check_sum = check_sum
+        self.length_width = 8
+        self.md5_width = 32
+        self.obj_length = None
+        self.tell = 0
+        self.obj_md5 = None
+        self.obj_body = None
+        self.sequence_number = -1
+        self.rank = -1
+        self.step = -1
+        self.sequence_number_dict = dict()
+
+    def connectionMade(self):
+        self.buffer = io.BytesIO()
+        self.obj_length = None
+        self.tell = 0
+        self.obj_md5 = None
+        self.obj_body = None
+        self.factory.transport_dict[self.transport] = 1
+        self.factory.transport_list.append(self.transport)
+        print_info_log(f"Client {self.transport.getPeer()} connected")
+
+    def connectionLost(self, reason):
+        self.factory.transport_dict.pop(self.transport, None)
+        if len(self.factory.transport_dict) == 0:
+            self.consumer_queue.put(b'KILL_')
+
+        print_info_log(f"Disconnected from client {self.transport.getPeer()}, reason: {reason}, "
+                       f"remaining connections: {len(self.factory.transport_dict)}")
+
+    def send_ack(self, ack_info):
+        self.transport.write(ack_info)
+
+    def post_process(self):
+        send_busy_ack = False
+        while self.consumer_queue.full():
+            if not send_busy_ack:
+                self.send_ack(self.ACK_BUSY +
+                              self.sequence_number.to_bytes(8, byteorder='big') +
+                              self.rank.to_bytes(8, byteorder='big') +
+                              self.step.to_bytes(8, byteorder='big'))
+                logger.debug("sending BUSY ACK")
+            send_busy_ack = True
+            time.sleep(0.1)
+
+        obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
+
+        if self.check_sum:
+            recv_md5 = hashlib.md5(self.obj_body).hexdigest()
+            if recv_md5 == self.obj_md5:
+                if self.obj_body == self.ACK_STOP:
+                    self.handle_with_stop()
+                else:
+                    self.send_ack(self.ACK_SUCCESS +
+                                  self.sequence_number.to_bytes(8, byteorder='big') +
+                                  self.rank.to_bytes(8, byteorder='big') +
+                                  self.step.to_bytes(8, byteorder='big'))
+                    if obj_key in self.sequence_number_dict:
+                        logger.debug(f"Duplicate retransmission, safe to ignore. {obj_key}, {self.sequence_number_dict}")
+                    else:
+                        self.sequence_number_dict[obj_key] = self.obj_md5
+                        self.consumer_queue.put(self.obj_body, block=True)
+            else:
+                logger.debug(
+                    f"Error: received corrupted data, sequence number {self.sequence_number}: expected {self.obj_md5}, but got {recv_md5}")
+
+                self.send_ack(self.ACK_ERROR + self.sequence_number.to_bytes(8, byteorder='big') +
+                              self.rank.to_bytes(8, byteorder='big') +
+                              self.step.to_bytes(8, byteorder='big'))
+        else:
+            if self.obj_body == self.ACK_STOP:
+                self.handle_with_stop()
+            else:
+                self.send_ack(self.ACK_SUCCESS + self.sequence_number.to_bytes(8, byteorder='big') +
+                              self.rank.to_bytes(8, byteorder='big') +
+                              self.step.to_bytes(8, byteorder='big'))
+                if obj_key in self.sequence_number_dict:
+                    logger.debug(f"Duplicate retransmission, safe to ignore. {obj_key}, {self.sequence_number_dict}")
+                else:
+                    self.sequence_number_dict[obj_key] = self.obj_md5
+                    self.consumer_queue.put(self.obj_body, block=True)
+
+        self.reset_env()
+        finish_time = time.time()
+        logger.debug(f"finish_time: {finish_time - self.start_time}")
+
+    def handle_with_stop(self):
+        logger.debug(f"Received stop-transfer signal from TCP {self.transport.getPeer()}")
+        self.send_ack(self.ACK_STOP_CONFIRM +
+                      self.sequence_number.to_bytes(8, byteorder='big') +
+                      self.rank.to_bytes(8, byteorder='big') +
+                      self.step.to_bytes(8, byteorder='big'))
+        if len(self.factory.transport_dict) == 0:
+            _rank, _step, _sequence_number = 0, 0, 100000000
+            ack_kill = self.ACK_KILL_PROCESS + \
+                       _sequence_number.to_bytes(8, byteorder='big') + \
+                       _rank.to_bytes(8, byteorder='big') + \
+                       _step.to_bytes(8, byteorder='big')
+            for trans in self.factory.transport_list:
+                trans.write(ack_kill)
+                logger.debug(f"Sent KILL message to {self.transport.getPeer()}")
+            self.consumer_queue.put(b'KILL_')
+            time.sleep(2)
+
+    def reset_env(self):
+        self.obj_length = None
+        self.sequence_number = -1
+        self.rank = -1
+        self.step = -1
+        self.obj_md5 = None
+        self.obj_body = None
+
+    def dataReceived(self, data):
+        self.buffer.seek(0, 2)
+        self.buffer.write(data)
+        self.buffer.seek(self.tell)
+        while True:
+            if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4:
+                # parse the header: payload length, sequence number, rank and step
+                self.start_time = time.time()
+                self.obj_length = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
+                self.sequence_number = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
+                self.rank = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
+                self.step = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
+                self.tell += self.length_width * 4
+                logger.debug(
+                    f"Sequence number: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
+
+            check_sum_and_md5 = self.check_sum and self.obj_length is not None and self.obj_md5 is None and len(
+                self.buffer.getvalue()) - self.tell >= self.md5_width
+            if check_sum_and_md5:
+                # extract the md5 digest of the payload
+                self.obj_md5 = self.buffer.read(self.md5_width).decode()
+                self.tell += self.md5_width
+                logger.debug(f"MD5: {self.obj_md5}")
+
+            current_length = len(self.buffer.getvalue()) - self.tell
+
+            if self.obj_length is not None and 0 < self.obj_length <= current_length:
+                self.obj_body = self.buffer.read(self.obj_length)
+
+                self.tell += self.obj_length
+                self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
+                self.buffer.seek(0)
+                self.tell = 0
+                recv_data_time = time.time()
+                logger.debug(f"self.sequence_number {self.sequence_number} "
+                             f"recv_data_time {recv_data_time - self.start_time}")
+
+                if self.obj_body == self.ACK_STOP:
+                    _transport = self.factory.transport_dict.pop(self.transport, None)
+                    logger.debug(f"Received b'STOP_', self.sequence_number {self.sequence_number} ")
+                self.post_process()
+                break
+            else:
+                break
+
+
+class MessageServerFactory(protocol.ServerFactory):
+    def __init__(self) -> None:
+        self.transport_dict = {}
+        self.transport_list = []
+
+    def is_all_connection_closed(self):
+        return len(self.transport_dict) == 0
diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py
index 2c9b13c4d..7504b5046 100644
--- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py
+++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py
@@ -86,8 +86,8 @@ class TestCompare(unittest.TestCase):
     def test_record_results(self):
         args = ('Functional*conv2d*0', False, 'N/A', [['torch.float64', 'torch.float32', (32, 64, 112, 112), 1.0,
                                                        0.012798667686, 'N/A', 0.81631212311, 0.159979121213, 'N/A',
-                                                       'error', '\n']], None)
-        self.compare.record_results(*args)
+                                                       'error', '\n']], None, 0)
+        self.compare.record_results(args)
         with open(self.details_csv_path, 'r') as file:
             csv_reader = csv.reader(file)
             next(csv_reader)
diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
index 3c180fa23..3412620f7 100644
--- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
@@ -62,10 +62,10 @@ class TestRunUtMethods(unittest.TestCase):
     def test_UtDataInfo(self):
         data_info = 
UtDataInfo(None, None, None, None, None, None, None) - self.assertIsNone(data_info.bench_grad) - self.assertIsNone(data_info.device_grad) - self.assertIsNone(data_info.device_output) - self.assertIsNone(data_info.bench_output) + self.assertIsNone(data_info.bench_grad_out) + self.assertIsNone(data_info.device_grad_out) + self.assertIsNone(data_info.device_out) + self.assertIsNone(data_info.bench_out) self.assertIsNone(data_info.grad_in) self.assertIsNone(data_info.in_fwd_data_list) self.assertIsNone(data_info.backward_message) -- Gitee
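For reviewers trying the feature end to end, the pieces added by this patch wire together roughly as follows. This is a hedged sketch rather than a documented workflow: the class names and signatures match the code above, but the port, the tensors, and the assumption of one benchmark process plus one separate device process are illustrative.

# process A: benchmark ("golden") side, e.g. a GPU machine
from api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig
golden = ATTL('gpu', ATTLConfig(is_golden=True, connect_port=8006))
api_data = golden.recv()  # blocks until an ApiData (or a control string) arrives

# process B: device side, streaming intercepted APIs to process A
import torch
from api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData
sender = ATTL('npu', ATTLConfig(is_golden=False, connect_ip="127.0.0.1", connect_port=8006))
x = torch.randn(4)
sender.send(ApiData("Functional*relu*0", (x,), {}, torch.relu(x), 0, 0))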