diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index 77bfa3f5e167a83d629a5e1cbc6980a4851bd1d4..87e197bfcb9789d0e6ca490998fdd59f2dc9bdae 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -44,7 +44,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" class PtdbgDispatch(TorchDispatchMode): - def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0): + def __init__(self, dump_mode=Const.OFF, api_list=[], debug=False, dump_path=None, tag=None, process_num=0): super(PtdbgDispatch, self).__init__() logger.info(COMPARE_LOGO) if not is_npu: @@ -182,7 +182,12 @@ class PtdbgDispatch(TorchDispatchMode): npu_out_cpu = safe_get_value(npu_out_cpu, 0, "npu_out_cpu") with TimeStatistics("CPU RUN", run_param): - cpu_out = func(*cpu_args, **cpu_kwargs) + try: + cpu_out = func(*cpu_args, **cpu_kwargs) + except RuntimeError: + self.api_index -= 1 + logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.") + return npu_out if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]: cpu_out = cpu_out.float()