diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md
index 5e3a0fa70ff3455b724bf6b283b46030d32bf783..7b612fbf3e2277e145b6bd471c809853ab541c45 100644
--- a/debug/accuracy_tools/msprobe/README.md
+++ b/debug/accuracy_tools/msprobe/README.md
@@ -84,7 +84,7 @@ msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API
精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
-PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md)
+PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)
MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py
index 5e5f7a62e6280c52dfb6c20d9840d9e866298f47..60074a4410d792e15852f8e2335333910d03cf37 100644
--- a/debug/accuracy_tools/msprobe/core/hook_manager.py
+++ b/debug/accuracy_tools/msprobe/core/hook_manager.py
@@ -36,10 +36,9 @@ class BaseHookManager(ABC):
hook_handle_dict = {}
params_grad_info = {}
- def __init__(self, data_collector, config, attl_manager=None):
+ def __init__(self, data_collector, config):
self.data_collector = data_collector
self.config = config
- self.attl_manager = attl_manager
@property
def _pid(self):
@@ -164,10 +163,7 @@ class BaseHookManager(ABC):
self._add_count(api_name)
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
self.data_collector.update_api_or_module_name(full_name)
- if getattr(self.config, "online_run_ut", False):
- BaseHookManager.inner_switch[tid] = False
- ThreadSafe.release()
- return
+
self.data_collector.forward_input_data_collect(
full_name,
module,
@@ -193,13 +189,7 @@ class BaseHookManager(ABC):
self.data_collector.update_api_or_module_name(full_name)
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
with self._no_grad_context():
- if getattr(self.config, "online_run_ut", False):
- if self.data_collector.scope and not self.data_collector.scope.check(full_name):
- return None
- if self.attl_manager:
- self.attl_manager.attl_send(full_name, args, kwargs, output)
- BaseHookManager.inner_switch[tid] = False
- return None
+
if hook_type == Const.MODULE:
params_dict = self._get_params_dict(module)
setattr(module_input_output, Const.PARAMS, params_dict)
@@ -243,9 +233,7 @@ class BaseHookManager(ABC):
with ThreadSafe():
BaseHookManager.inner_switch[tid] = True
self.data_collector.update_api_or_module_name(full_name)
- if getattr(self.config, "online_run_ut", False):
- BaseHookManager.inner_switch[tid] = False
- return
+
need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
if need_exchange:
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
diff --git a/debug/accuracy_tools/msprobe/core/service.py b/debug/accuracy_tools/msprobe/core/service.py
index f435549d1958ba74c8026f0b01f086aca295d931..7b47665d0310d399d7e7f4af174570484ca84419 100644
--- a/debug/accuracy_tools/msprobe/core/service.py
+++ b/debug/accuracy_tools/msprobe/core/service.py
@@ -36,7 +36,6 @@ class BaseService(ABC):
self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置
self.model = None
self.data_collector = build_data_collector(self.config)
- self.attl_manager = None
self.current_iter = 0
self.loop = 0
self.init_step = 0
@@ -92,9 +91,6 @@ class BaseService(ABC):
(self.config.task == Const.STATISTICS and self.config.tensor_list)
)
- @property
- def _is_online_run_ut(self):
- return getattr(self.config, "online_run_ut", False)
@property
@abstractmethod
@@ -143,11 +139,9 @@ class BaseService(ABC):
self.primitive_switch = True
self._change_jit_switch(True)
self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
- if self._is_online_run_ut:
- self._run_ut_dispatch(True)
- else:
- self.create_dirs()
- self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
+
+ self.create_dirs()
+ self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
def stop(self):
"""通用stop模板"""
@@ -162,8 +156,7 @@ class BaseService(ABC):
self._change_jit_switch(False)
if self._is_l2_level:
return
- if self._is_online_run_ut:
- self._run_ut_dispatch(False)
+
self._process_async_dump()
self.data_collector.write_json()
@@ -263,8 +256,6 @@ class BaseService(ABC):
end_service = self.config.step and self.current_iter > max(self.config.step) or \
self.data_collector and self.data_collector.data_processor.is_terminated
if end_service:
- if self._is_online_run_ut and self.attl_manager:
- self.attl_manager.attl_stop()
self.primitive_switch = False
self._change_jit_switch(False)
Runtime.is_running = False
@@ -307,8 +298,7 @@ class BaseService(ABC):
if root_model and isinstance(root_model, list):
root_model = root_model[0]
self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
- if self._is_online_run_ut:
- return
+
root_model.register_forward_pre_hook(infer_hook)
def _create_l2_dirs(self, cur_rank):
diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
index e00b9c86bd6fa1ba7d2b54292859fb6fb583172d..3efa541b53660c3fb8e7e63d7bb92b1d5db978e8 100644
--- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
+++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
@@ -73,14 +73,7 @@
| data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 |
| file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:
"bin":dump 的 tensor 文件为二进制格式;
"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 |
| summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 |
-| online_run_uta | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 |
-| nfs_patha | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 |
-| hosta | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
-| porta | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 |
-**说明**:
-
-1. online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。
**示例**:
- [PyTorch场景](03.config_examples.md#12-task-配置为-tensor)
@@ -95,17 +88,11 @@
| white_lista | API dump 白名单,仅对指定的 API 进行 dump。
**配置示例**:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即 dump 全量 API 数据。 | 否 |
| black_lista | API dump 黑名单,被指定的 API 不进行 dump。
**配置示例**:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即 dump 全量 API 数据。 | 否 |
| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,默认为当前路径。
**配置示例**:"error_data_path": "./"。 | 否 |
-| is_onlineb | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认关闭。 | 否 |
-| nfs_pathb | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效,仅在 is_online 字段配置为 true 时生效。 | 否 |
-| hostb | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。 | 否 |
-| portb | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 |
-| rank_listb | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 |
**说明**:
1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。
-2. is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。
**示例**:
```json
diff --git a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md b/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md
deleted file mode 100644
index 119827abef46914d2aa5e77d91c57813952e92f6..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md
+++ /dev/null
@@ -1,295 +0,0 @@
-# PyTorch 场景的在线精度预检
-
-## 1 简介
-
-为了应对大模型场景下,通过离线预检方式 dump API 输入输出数据导致的存储资源紧张问题,提供在线精度预检功能。本功能实现在执行 NPU 训练操作的过程中,通过 TCP/IP 协议在 NPU
-Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据在 GPU 设备上运行,将两份输出数据进行比对,得到预检比对结果,从而减少数据 dump 的步骤,降低存储资源的占用。针对偏差较大的算子,两方比对(NPU vs. GPU)的方法缺少裁判进行裁定。 参考离线预检,在线预检场景同时支持两方比对和三方比对方式,按照 api 的精度标准要求,选择两方比对或三方比对。
-
-## 2 在线精度预检流程
-
-在线精度预检当前支持**局域网场景**和**共享存储场景**,请根据不同的场景选择对应的配置。
-
-在线精度预检操作流程如下:
-
-1. 准备 GPU 和 NPU 可正常运行的训练环境,PyTorch 版本大于等于2.0,并保证两台 Host 在同一局域网内可正常通信或能通过共享存储进行通信。
-2. GPU 和 NPU Host 设备上同时安装msprobe工具,详见[ msprobe 安装](./01.installation.md)章节,其中在线预检要安装 twisted、pyOpenSSL,这些包为 Python 模块。
-3. 分别配置 GPU 侧、NPU 侧的 config.json 文件。
-4. 在 GPU 侧运行 `msprobe -f pytorch run_ut -config ./config.json`。
-5. 在 NPU 侧配置训练脚本。
-6. 在 NPU 侧执行训练。
-
-## 3 在线精度预检操作指导
-
-### 3.1 配置 config.json 文件
-
-预检工具安装完成后,需要在 GPU 和 NPU 环境下分别配置 config.json。其中需要重点关注文件中的 is_online、is_benchmark_device、host 和 port 参数的配置,保障在线预检时 GPU 和 NPU 两台设备间的通信正常。
-
-#### 3.1.1 GPU 侧在线预检配置说明
-
-| 参数名称 | 说明 | 是否必选 |
-|-----------------|--------------|------|
-| task | 任务名称,str 类型,配置为 run_ut 表示预检任务。通过其他字段 is_online 判断离线预检、在线预检任务。 | 是 |
-| white_list | 预检的 API 白名单,list[str] 类型。
**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置白名单,即预检全量 API 数据。 | 否 |
-| black_list | 预检的 API 黑名单,list[str] 类型。
**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置黑名单,即预检全量 API 数据。 | 否 |
-| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,str 类型。在线预检模式下该参数不生效。 | 否 |
-| is_online | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
-| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host、port 和 tls_path 不生效。 | 否 |
-| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
-| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
-| rank_list | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。 | 是 |
-| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 server.key、证书 server.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
-
-
-#### 3.1.2 NPU 侧在线预检配置说明
-
-| 参数名称 | 说明 | 是否必选 |
-|------------------|-------------|------|
-| task | 任务名称,str 类型,配置为 tensor 表示 dump API 统计信息和完全复刻整网的 API 运行情况的真实数据。通过字段 online_run_ut 判断是否使用在线预检功能。 | 是 |
-| dump_path | dump 路径,str 类型,配置为合法路径即可,兼容 tensor 任务静态检查。 | 是 |
-| level | dump 级别,str 类型,在线预检时配置为 L1,表示 dump API 级精度数据。在线预检可不配置,默认取值 L1。 | 是 |
-| rank | 指定对某张卡上的数据进行 dump,list[int] 类型,默认未配置(表示 dump所有卡的数据),需要与 GPU 侧配置项 rank_list 保持一致。 | 否 |
-| step | 指定 dump 某个 step 的数据,list[int] 类型,默认未配置,表示 dump 所有 step 的数据。dump 特定 step 时,须指定为训练脚本中存在的 step。 | 否 |
-| scope | dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
-| list | dump 范围,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
-| online_run_ut | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
-| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效。 | 否 |
-| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
-| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
-| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 client.key、证书 client.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
-| online_run_ut_recompute | 模型训练是否使用重计算机制,bool类型,默认为False,表示模型没有使用重计算。在线预检暂不支持重计算机制下反向算子的预检,当模型训练使用重计算时,跳过反向算子预检,默认模型关闭重计算。 | 否 |
-
-#### 3.1.3 局域网场景配置示例
-
-若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法。
-
-以下秘钥生成方法仅为简单示例,客户应使用与自己需求相符的秘钥生成和存储机制并保证秘钥安全性与机密性,必要时可采用分层秘钥机制。
-以下示例中加密口令仅供参考,使用时请更换为复杂口令,并保护口令安全。
-```shell
-# 生成CA证书的根私钥和证书签名请求,其中ca_password为CA私钥加密口令,仅作演示,请更换使用
-openssl req -new -newkey rsa:3072 -passout pass:ca_password -subj "/CN=*ca.com/O=ca.Inc./C=CN/ST=Zhejiang/L=Hangzhou" -keyout ca.key -out ca.csr
-# 自签发根证书
-openssl x509 -req -days 365 -in ca.csr -signkey ca.key -passin pass:ca_password -out ca.crt -extensions v3_ca -extfile <(cat <<-EOF
-[v3_ca]
-basicConstraints = critical,CA:true
-keyUsage = critical, keyCertSign, cRLSign
-EOF
-)
-
-# 生成client公私钥,其中client_password为私钥加密口令,仅作演示,请更换使用
-openssl genrsa -aes256 -passout pass:client_password -out client.key 3072
-# 基于client公私钥生成签名请求
-openssl req -new -key client.key -passin pass:client_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out client.csr
-# 利用自签发的根证书,签发client证书
-openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in client.csr -out client.crt -CAcreateserial -extfile <(cat <<-EOF
-[v3_server]
-basicConstraints = CA:FALSE
-keyUsage = critical, digitalSignature, keyEncipherment
-extendedKeyUsage = serverAuth
-EOF
-)
-
-# 生成server公私钥,其中server_password为私钥加密口令,仅作演示,请更换使用
-openssl genrsa -aes256 -passout pass:server_password -out server.key 3072
-# 基于server公私钥生成签名请求
-openssl req -new -key server.key -passin pass:server_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out server.csr
-# 利用自签发的根证书,签发server证书
-openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in server.csr -out server.crt -CAcreateserial -extfile <(cat <<-EOF
-[v3_server]
-basicConstraints = CA:FALSE
-keyUsage = critical, digitalSignature, keyEncipherment
-extendedKeyUsage = serverAuth
-EOF
-)
-
-```
-
-当需要吊销已创建的SSL证书时,通过openssl命令生成CRL证书 crl.pem,示例如下:
-```shell
-# 创建证书信息的文本数据库,空文件即可
-touch index.txt
-
-# 创建ca配置文件ca.cnf,内容如下,用于吊销证书使用
-[ca]
-default_ca = CA_default
-[CA_default]
-database = ./index.txt
-default_md = sha256
-
-# 吊销证书 client.crt,其中ca_password为CA私钥加密口令,与CA创建时保持一致
-openssl ca -revoke client.crt -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password
-# 生成CRL文件
-openssl ca -gencrl -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password -out crl.pem -crldays 30
-# 查看生成的CRL文件内容:
-openssl工具的命令: openssl crl -inform PEM -in crl.pem -text
-
-```
-
-注意:配置TLS协议时,传输性能受机器环境和网络质量的影响,可能触发NPU超时中断模型训练,为避免训练和预检中断,丢弃长时间未传输的api数据,同时NPU侧配置HCCL环境变量,配置方式如下:
-
-a) 调整HCCL环境变量,关闭看门狗,避免WorkHCCL超时中断模型训练:
-```shell
-export HCCL_DESYNC_DEBUG=0
-export HCCL_ASYNC_ERROR_HANDLING=0
-```
-b) 调整通信算子超时设置(以1800s举例):
-```shell
-export HCCL_CONNECT_TIMEOUT=1800
-export HCCL_EXEC_TIMEOUT=1800
-```
-
-GPU 侧:
-
-```json
-{
- "task": "run_ut",
- "run_ut": {
- "white_list": [],
- "black_list": [],
- "error_data_path": "./",
- "is_online": true,
- "nfs_path": "",
- "host": "127.0.0.1",
- "port": 59208,
- "rank_list": [0],
- "tls_path": ""
- }
-}
-```
-
-NPU 侧:
-
-```json
-{
- "task": "tensor",
- "dump_path": "./dump_path",
- "rank": [0],
- "step": [0],
- "level": "L1",
- "tensor": {
- "scope": [],
- "list": [],
- "online_run_ut": true,
- "nfs_path": "",
- "host": "xx.xx.xx.x",
- "port": 59208,
- "tls_path": ""
- }
-}
-```
-
-#### 3.1.4 共享存储场景配置示例
-
-GPU 侧:
-
-```json
-{
- "task": "run_ut",
- "run_ut": {
- "white_list": [],
- "black_list": [],
- "error_data_path": "./",
- "is_online": true,
- "nfs_path": "/nfs/xxx/data",
- "host": "",
- "port": -1,
- "rank_list": [0],
- "tls_path": ""
- }
-}
-```
-
-NPU 侧:
-
-```json
-{
- "task": "tensor",
- "dump_path": "./dump_path",
- "rank": [0],
- "step": [0],
- "level": "L1",
- "tensor": {
- "scope": [],
- "list": [],
- "online_run_ut": true,
- "nfs_path": "/nfs/xxx/data",
- "host": "",
- "port": -1,
- "tls_path": ""
- }
-}
-```
-
-### 3.2 在 GPU 侧运行 run_ut
-
-由于 GPU 侧为通信接收端,需先于 NPU 侧执行 run_ut 操作,命令如下:
-
-```bash
-msprobe -f pytorch run_ut -config ./config.json
-```
-
-GPU 侧配置好 config.json 文件后执行 run_ut 命令,此时 GPU 处于预检等待状态:
-
-- 局域网场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到 GPU 侧时,GPU 启动预检操作。
-- 共享存储场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到共享存储时,GPU 启动预检操作。
-
-### 3.3 在 NPU 侧配置训练脚本
-
-在 NPU 训练脚本中添加如下代码以获取 run_ut 操作的预检 API 输入和输出数据:
-
-```python
-from msprobe.pytorch import PrecisionDebugger
-
-debugger = PrecisionDebugger("config.json")
-...
-
-debugger.start()
-
-...
-
-debugger.stop()
-debugger.step()
-```
-
-### 3.4 在 NPU 侧执行训练脚本
-
-配置完 NPU 侧训练脚本后即可执行训练脚本,命令示例如下:
-
-```bash
-bash train.sh
-```
-
-训练脚本执行完毕后,在GPU侧dump_path目录下生成比对结果文件,`accuracy_checking_result_{timestamp}_rank{rank_id}.csv`和`accuracy_checking_details_{timestamp}_rank{rank_id}.csv`记录两方比对结果,`api_precision_compare_result_{timestamp}_rank{rank_id}.csv`和`api_precision_compare_details_{timestamp}_rank{rank_id}.csv`记录三方比对结果。详细介绍请参见[离线精度预检中的 **4 预检结果**](./07.accuracy_checker_PyTorch.md#4-预检结果)。
-
-## 4 支持的融合算子列表
-
-预检工具当前支持的融合算子如下:
-
-- npu_apply_adam_w
-
-- npu_confusion_transpose
-
-- fast_gelu
-
-- npu_layer_norm_eval
-
-- npu_linear
-
-- npu_fusion_attention(该算子在 GPU 上预检时,需要额外安装 flash_attn,请用户自行安装。)
-
-- npu_rms_norm
-
-- npu_rotary_mul
-
-- npu_scaled_masked_softmax
-
-- npu_swiglu
-
-- npu_apply_adam
-
-- npu_group_norm_silu
-
-- npu_mish
-
-- npu_moe_gating_top_k_softmax
-
-- npu_sort_v2
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md
index 8ceaf3c4a3b5f5ea8c6a0e8aeeab56a7bbf39d52..686660b09991ae66e48a3423458fd239b1f01099 100644
--- a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md
+++ b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md
@@ -7,7 +7,6 @@
| [数据采集
(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失
3、当前对inplace操作API或Module的支持度有限
4、暂不支持参数及参数梯度的采集 |
| [离线预检
(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、依赖GPU环境
2、不支持通信算子
3、仅支持部分融合算子 |
| [整网比对
(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 |
-| [在线预检
(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信
2、重计算模式下,不支持反向aten算子预检 |
| [溢出检查
(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module
2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 |
| [数据解析
(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU |
| [无标杆比对
(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 |
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
index 1e844ff81a8543c9865dbefc3c39c12202d2c6e2..7c9beab9b281a567f80c0049c444c659bbd84d6c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
@@ -24,8 +24,7 @@ from msprobe.pytorch.pt_config import RunUTConfig
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
- 'black_list', 'error_data_path', 'online_config'])
-OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
+ 'black_list', 'error_data_path'])
class Config:
@@ -46,13 +45,7 @@ class Config:
'white_list': list,
'black_list': list,
'error_data_path': str,
- 'precision': int,
- 'is_online': bool,
- 'nfs_path': str,
- 'host': str,
- 'port': int,
- 'rank_list': list,
- 'tls_path': str
+ 'precision': int
}
if key not in validators:
raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -68,10 +61,6 @@ class Config:
RunUTConfig.check_filter_list_config(key, value)
if key == 'error_data_path':
RunUTConfig.check_error_data_path_config(value)
- if key == 'nfs_path':
- RunUTConfig.check_nfs_path_config(value)
- if key == 'tls_path':
- RunUTConfig.check_tls_path_config(value)
return value
@@ -85,12 +74,6 @@ class CheckerConfig:
self.white_list = msCheckerConfig.white_list
self.black_list = msCheckerConfig.black_list
self.error_data_path = msCheckerConfig.error_data_path
- self.is_online = msCheckerConfig.is_online
- self.nfs_path = msCheckerConfig.nfs_path
- self.host = msCheckerConfig.host
- self.port = msCheckerConfig.port
- self.rank_list = msCheckerConfig.rank_list
- self.tls_path = msCheckerConfig.tls_path
if task_config:
self.load_config(task_config)
@@ -99,22 +82,7 @@ class CheckerConfig:
self.white_list = task_config.white_list
self.black_list = task_config.black_list
self.error_data_path = task_config.error_data_path
- self.is_online = task_config.is_online
- self.nfs_path = task_config.nfs_path
- self.host = task_config.host
- self.port = task_config.port
- self.rank_list = task_config.rank_list
- self.tls_path = task_config.tls_path
- def get_online_config(self):
- return OnlineConfig(
- is_online=self.is_online,
- nfs_path=self.nfs_path,
- host=self.host,
- port=self.port,
- rank_list=self.rank_list,
- tls_path=self.tls_path
- )
def get_run_ut_config(self, **config_params):
return RunUtConfig(
@@ -127,6 +95,5 @@ class CheckerConfig:
real_data_path=config_params.get('real_data_path'),
white_list=self.white_list.copy() if self.white_list else [],
black_list=self.black_list.copy() if self.black_list else [],
- error_data_path=config_params.get('error_data_path'),
- online_config=self.get_online_config()
+ error_data_path=config_params.get('error_data_path')
)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
index 55e93d271cec67334fe21c1f6466df2d0254a36b..24ac8b17ced04ea186898644925c77912dd294be 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py
@@ -117,30 +117,6 @@ def api_precision_compare(config):
change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
-def online_api_precision_compare(online_config):
- rank = online_config.rank
- result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
- "_rank*.csv", f"_rank{rank}.csv")
- details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
- "_rank*.csv", f"_rank{rank}.csv")
- detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
- result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
- if not os.path.exists(result_csv_path):
- write_csv(result_csv_title, result_csv_path)
- if not os.path.exists(details_csv_path):
- write_csv(detail_csv_title, details_csv_path)
- config = CompareConfig("", "", result_csv_path, details_csv_path)
- try:
- npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
- check_csv_columns(npu_data.columns, "npu_csv")
- check_csv_columns(gpu_data.columns, "gpu_csv")
- analyse_csv(npu_data, gpu_data, config)
- except Exception as err:
- logger.error(f"Online api precision compare Error: {str(err)}")
- change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
- change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
-
-
def analyse_csv(npu_data, gpu_data, config):
forward_status, backward_status = [], []
last_api_name, last_api_dtype, last_api_full_name = None, None, None
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
index c12a54c18ad07ae302b41d12704dc82fec01b4c2..3387faaf96ec9eabb17d26ae98860c0d3b468ba0 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
@@ -66,13 +66,6 @@ class Comparator:
self.save_path_list = [result_csv_path]
self.detail_save_path_list = [details_csv_path]
- if config and config.online_config.is_online:
- self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
- self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
- self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
- self.detail_save_path_list = \
- [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
-
self.registry = self._register_compare_func()
if not is_continue_run_ut:
@@ -245,9 +238,8 @@ class Comparator:
self.write_detail_csv(args)
- def compare_output(self, full_api_name, data_info, is_online=False):
+ def compare_output(self, full_api_name, data_info):
"""Get compare result and write to result and detail csv.
- is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
"""
_, api_name = extract_basic_api_segments(full_api_name)
if not api_name:
@@ -280,9 +272,7 @@ class Comparator:
fwd_compare_alg_results,
bwd_compare_alg_results,
data_info.rank)
- if is_online:
- # get run_ut compare detail
- return self._get_run_ut_detail(result_info)
+
self.record_results(result_info)
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
or bwd_success_status == CompareConst.SPACE
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
index 2ec9251009e61ef68dbfed987abe457d47b91e9a..2797d0c64cccad149727c4c6a1b86c5cb4290350 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
@@ -2,9 +2,4 @@ white_list: []
black_list: []
error_data_path: './'
precision: 14
-is_online: False
-nfs_path: ""
-host: ""
-port: -1
-rank_list: [0]
-tls_path: "./"
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
index 082f391c957578bad9b1dff546803aa7d4ce05b0..4bf8ead7e6a7f686f2cb2f457884f054ab6e5237 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
from msprobe.pytorch.common.utils import seed_all
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
ExecParams
@@ -90,27 +88,22 @@ seed_all()
def run_ut(config):
logger.info("start UT test")
- if config.online_config.is_online:
- logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
- logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
- else:
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
+
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
if config.save_error_data:
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
- if config.online_config.is_online:
- run_api_online(config, compare)
- else:
- csv_df = read_csv(config.result_csv_path)
- try:
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
- except IndexError:
- logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
- api_name_set = set()
- run_api_offline(config, compare, api_name_set)
+
+ csv_df = read_csv(config.result_csv_path)
+ try:
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
+ except IndexError:
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
+ api_name_set = set()
+ run_api_offline(config, compare, api_name_set)
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
gc.collect()
-def run_api_online(config, compare):
- attl = init_attl(config.online_config)
- dispatcher = ConsumerDispatcher(compare=compare)
- dispatcher.start(handle_func=run_torch_api_online, config=config)
-
- def tcp_communication_flow():
- while True:
- api_data = attl.recv()
- if api_data == 'STOP_':
- continue
- if api_data == 'KILL_':
- time.sleep(1)
- logger.info("==========接收到STOP信号==========")
- dispatcher.stop()
- attl.stop_serve()
- time.sleep(1)
- break
- if not isinstance(api_data, ApiData):
- continue
- api_full_name = api_data.name
- _, api_name = extract_basic_api_segments(api_full_name)
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
- continue
- if api_data.rank in config.online_config.rank_list:
- dispatcher.update_consume_queue(api_data)
-
- def shared_storage_communication_flow():
- flag_num = -1
- while True:
- api_data = attl.download()
- if api_data == "start":
- if flag_num == -1:
- flag_num += 1
- flag_num += 1
- if api_data == "end":
- flag_num -= 1
- if flag_num == 0:
- dispatcher.stop()
- break
- if not isinstance(api_data, ApiData):
- continue
- api_full_name = api_data.name
- _, api_name = extract_basic_api_segments(api_full_name)
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
- continue
- if api_data.rank in config.online_config.rank_list:
- dispatcher.update_consume_queue(api_data)
-
- if config.online_config.nfs_path:
- shared_storage_communication_flow()
- else:
- tcp_communication_flow()
-
-
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
"""
run api(api_name) if api_name not in black_list and in white_list.
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
-def run_torch_api_online(api_full_name, api_data, backward_content):
- in_fwd_data_list = []
- api_type, api_name = extract_basic_api_segments(api_full_name)
- args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
- in_fwd_data_list.append(args)
- in_fwd_data_list.append(kwargs)
- if kwargs.get("device"):
- del kwargs["device"]
-
- device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
- device_out = exec_api(device_exec_params)
- device_out = move2device_exec(device_out, "cpu")
- return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
-
-
def check_need_grad(api_info_dict):
need_grad = True
if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
return error_data_path
-def init_attl(config):
- """config: OnlineConfig"""
- attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
- connect_ip=config.host,
- connect_port=config.port,
- nfs_path=config.nfs_path,
- tls_path=config.tls_path))
- return attl
-
-
def _run_ut_parser(parser):
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
help=" The api param tool result file: generate from api param tool, "
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
_run_ut_parser(parser)
args = parser.parse_args(sys.argv[1:])
run_ut_command(args)
-
-
-def checked_online_config(online_config):
- if not online_config.is_online:
- return
- if not isinstance(online_config.is_online, bool):
- raise ValueError("is_online must be bool type")
- # rank_list
- if not isinstance(online_config.rank_list, list):
- raise ValueError("rank_list must be a list")
- if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
- raise ValueError("All elements in rank_list must be integers")
-
- # nfs_path
- if online_config.nfs_path:
- check_file_or_directory_path(online_config.nfs_path, isdir=True)
- return
- # tls_path
- if online_config.tls_path:
- check_file_or_directory_path(online_config.tls_path, isdir=True)
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
- check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
- crl_path = os.path.join(online_config.tls_path, "crl.pem")
- if os.path.exists(crl_path):
- check_file_or_directory_path(crl_path)
-
- # host and port
- if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
- raise Exception(f"host: {online_config.host} is invalid.")
- if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
- raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
def run_ut_command(args):
@@ -525,7 +407,7 @@ def run_ut_command(args):
else:
checker_config = CheckerConfig()
- if not checker_config.is_online and not args.api_info_file:
+ if not args.api_info_file:
logger.error("Please provide api_info_file for offline run ut.")
raise Exception("Please provide api_info_file for offline run ut.")
@@ -588,8 +470,6 @@ def run_ut_command(args):
global UT_ERROR_DATA_DIR
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
error_data_path = initialize_save_error_data(error_data_path)
- online_config = checker_config.get_online_config()
- checked_online_config(online_config)
config_params = {
'forward_content': forward_content,
'backward_content': backward_content,
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py
deleted file mode 100644
index 2cfc355ec035d245261ca9c817e02687c684d471..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import glob
-import os.path
-import time
-from multiprocessing import Queue
-from typing import Optional, Union, Dict, Any
-from dataclasses import dataclass
-
-import torch
-
-from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
-from msprobe.core.common.file_utils import remove_path
-from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
-from msprobe.core.common.decorator import recursion_depth_decorator
-
-BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
-
-
-@dataclass
-class ATTLConfig:
- is_benchmark_device: bool
- connect_ip: str
- connect_port: int
- # storage_config
- nfs_path: str = None
- tls_path: str = None
- check_sum: bool = True
- queue_size: int = 50
-
-
-class ATTL:
- def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
- self.session_id = session_id
- self.session_config = session_config
- self.logger = logger
- self.socket_manager = None
- self.data_queue = Queue(maxsize=50)
- self.dequeue_list = []
- self.message_end = False
- self.kill_progress = False
- self.nfs_path = None
- if self.session_config.nfs_path:
- self.nfs_path = self.session_config.nfs_path
- elif self.session_config.is_benchmark_device:
-
- self.socket_manager = TCPServer(self.session_config.connect_port,
- self.data_queue,
- self.session_config.check_sum,
- self.session_config.tls_path)
- self.socket_manager.start()
- elif need_dump:
- self.socket_manager = TCPClient(self.session_config.connect_ip,
- self.session_config.connect_port,
- self.session_config.check_sum,
- self.session_config.tls_path)
- self.socket_manager.start()
-
- def stop_serve(self):
- if isinstance(self.socket_manager, TCPServer):
- self.socket_manager.stop()
-
- def send(self, buffer: BufferType) -> None:
- """
- npu major in 'send' (client)
- """
-
- # if tcp connection lost,
- if self.socket_manager.signal_exit:
- raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
-
- # know receiver receive and go next
- if isinstance(buffer, ApiData):
- buffer = move2target_device(buffer, torch.device('cpu'))
-
- if 'device' in buffer.kwargs:
- buffer.kwargs.pop('device')
- rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
- step = buffer.step if hasattr(buffer, "step") else 0
- try:
- io_buff = save_api_data(buffer)
- except Exception as e:
- self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
- return
- data = io_buff.getvalue()
- self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
-
- def recv(self, timeout_ms=0) -> Optional[BufferType]:
- buffer = ''
- while not buffer:
- if timeout_ms > 0:
- time.sleep(timeout_ms / 1000.0)
- if not buffer and not self.data_queue.empty():
- buffer = self.data_queue.get()
- break
- if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None
- break
- if self.message_end and self.data_queue.empty():
- buffer = b"KILL_CONFIRM"
- self.kill_progress = True
- break
- time.sleep(0.1) # waiting outside the lock before next attempt
- if not buffer:
- # this is a result of a timeout
- self.logger.info(f"RECEIVE API DATA TIMED OUT")
- else:
- if buffer == b"STOP_":
- return "STOP_"
- if buffer == b"KILL_":
- self.message_end = True
- return "STOP_"
- if buffer == b"KILL_CONFIRM":
- self.kill_progress = True
- return "KILL_"
- try:
- buffer = load_api_data(buffer)
- except Exception as e:
- self.logger.warning("there is something error. please check it. %s", e)
- if isinstance(buffer, bytes):
- return ''
- if isinstance(buffer, str):
- return buffer
-
- return buffer
-
- def upload(self, buffer: BufferType):
- if isinstance(buffer, ApiData):
- buffer = move2target_device(buffer, torch.device('cpu'))
- file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
- else:
- file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
-
- try:
- save_pkl(buffer, file_path)
- except Exception as e:
- self.logger.warning("there is something error in save_pt. please check it. %s", e)
-
- def download(self):
- buffer = None
- cur_file = None
- for file_type in ("start*", "*.pt", "end*"):
- pattern = os.path.join(self.nfs_path, file_type)
- files = glob.glob(pattern)
- if len(files) > 0:
- cur_file = files[0]
- break
-
- if cur_file is not None:
- try:
- buffer = load_pkl(cur_file)
- except Exception as e:
- self.logger.warning("there is something error. please check it. %s", e)
- remove_path(cur_file)
- return buffer
-
-
-@recursion_depth_decorator("move2device_exec")
-def move2device_exec(obj, device):
- if isinstance(obj, (tuple, list)):
- data_list = [move2device_exec(val, device) for val in obj]
- return data_list if isinstance(obj, list) else tuple(data_list)
- if isinstance(obj, dict):
- return {key: move2device_exec(val, device) for key, val in obj.items()}
- elif isinstance(obj, torch.Tensor):
- obj = obj.detach()
- if obj.device.type != device:
- obj = obj.to(device)
- return obj
- elif "return_types" in str(type(obj)):
- return move2device_exec(tuple(obj), device)
- elif isinstance(obj, torch._C.device):
- return torch.device(device)
- else:
- return obj
-
-
-def move2target_device(buffer: ApiData, target_device):
- # handle args
- new_args = move2device_exec(buffer.args, target_device)
-
- # handle kwargs
- new_kwargs = move2device_exec(buffer.kwargs, target_device)
-
- # handle result
- new_results = move2device_exec(buffer.result, target_device)
-
- if target_device == torch.device('cpu') or target_device == "cpu":
- return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
- else:
- return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py
deleted file mode 100644
index a55ecae283105ed3d3127b862fc817ca371732db..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py
+++ /dev/null
@@ -1,378 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from functools import partial
-import zlib
-import io
-import struct
-import time
-import os
-from queue import Queue
-from threading import Thread
-from typing import Union
-
-from twisted.internet import reactor, protocol, endpoints, ssl
-from twisted.protocols.basic import FileSender
-
-from msprobe.pytorch.common.utils import logger
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
- STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem
-
-MAX_SENDING_QUEUE_SIZE = 20
-
-
-class TCPDataItem:
- def __init__(self, data,
- sequence_number: int,
- rank: int = 0,
- step: int = 0):
- self.raw_data = data
- self.sequence_number = sequence_number
- self.rank = rank
- self.step = step
- self.retry_times = 0
- self.pending_time = 0
- self.busy_time = 0
-
-
-class TCPClient:
- ACK_SUCCESS = b"OK___"
- ACK_ERROR = b"ERROR"
- ACK_BUSY = b"BUSY_"
- ACK_STOP = b"STOP_"
- ACK_STOP_CONFIRM = b"OVER_"
- ACK_KILL_PROCESS = b"KILL_"
-
- QUEUE_PENDING_TIME = 60
- RESEND_RETRY_TIMES = 2 # 最大重传数
- RESEND_TIMER_TIME = 5 # 接收ACK超时定时器
- RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据
-
- def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
- self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
- self.resend_dict = dict()
- self.host = host
- self.port = port
- self.tls_path = tls_path
- self.factory = None
- self.sequence_number = 0
- self.signal_exit = False
- self.tcp_manager = ClientProtocol(ack_queue_size=100,
- chunk_size=655360,
- check_sum=check_sum,
- tls=self.tls_path)
- self.send_thread = Thread(target=self._sending_queue_data)
- self.send_thread.setDaemon(True)
- self.send_thread.start()
- self.destroy_thread = Thread(target=self._destroy_queue_data)
- self.destroy_thread.setDaemon(True)
- self.destroy_thread.start()
-
- @staticmethod
- def run_reactor():
- reactor.run(installSignalHandlers=False)
-
- def start(self):
- def conn_callback(cur_protocol):
- if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
- logger.debug(f"Process: {os.getpid()} connects to server successfully.")
- else:
- logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
- raise ConnectionError(f"Failed to connect to {self.host}.")
-
- def conn_err_callback(failure):
- self.signal_exit = True
- time.sleep(1)
- reactor.stop()
- logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
-
- def cur_protocol():
- return self.tcp_manager
-
- self.factory = MessageClientFactory()
- self.factory.protocol = cur_protocol
- if self.tls_path:
- client_key, client_crt, ca_crt, crl_pem = load_ssl_pem(
- key_file=os.path.join(self.tls_path, "client.key"),
- cert_file=os.path.join(self.tls_path, "client.crt"),
- ca_file=os.path.join(self.tls_path, "ca.crt"),
- crl_file=os.path.join(self.tls_path, "crl.pem")
- )
-
- ssl_options = ssl.CertificateOptions(
- privateKey=client_key,
- certificate=client_crt,
- method=ssl.SSL.TLSv1_2_METHOD,
- verify=True,
- requireCertificate=True,
- caCerts=[ca_crt], # 信任的CA证书列表
- )
- ssl_context = ssl_options.getContext()
- ssl_context.set_cipher_list(cipher_list)
- ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
- ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
- partial(verify_callback, crl=crl_pem))
-
- endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options)
- else:
- endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
- d = endpoint.connect(self.factory)
- d.addCallback(conn_callback)
- d.addErrback(conn_err_callback)
-
- reactor_thread = Thread(target=self.run_reactor, daemon=True)
- reactor_thread.start()
-
- def send_after_queue_empty(self, data):
- while not self._ready_to_exit():
- if not self.tls_path:
- self.add_to_sending_queue(data)
- else:
- for _ in range(MAX_SENDING_QUEUE_SIZE):
- self.add_to_sending_queue(data)
- time.sleep(2)
-
- def check_client_alive(self):
- return self.factory.num_connections > 0
-
- def stop(self):
- self.tcp_manager.connection_timeout()
-
- def send_stop_signal(self):
- self.send_after_queue_empty(self.ACK_STOP)
- while not self._ready_to_exit():
- if not self.check_client_alive():
- break
- time.sleep(1)
-
- def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
- if self._ready_to_exit():
- return
-
- send_data = data
- if not isinstance(data, TCPDataItem):
- send_data = TCPDataItem(data=data,
- sequence_number=self.sequence_number,
- rank=rank,
- step=step)
- self.sequence_number += 1
- try:
- self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
- except Exception as e:
- logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
- f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
- f"{str(e)}")
-
- def _send_data(self, data: TCPDataItem):
- self.tcp_manager.send_wrapped_data(data.raw_data,
- sequence_number=data.sequence_number,
- rank=data.rank,
- step=data.step
- )
-
- def _sending_queue_data(self):
- while True:
- if not self.tcp_manager.is_connected:
- continue
-
- while self.send_queue.qsize() > 0:
- if self._ready_to_exit():
- break
- if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
- data_obj = self.send_queue.get()
- resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
- logger.debug(f"get {resend_key} from send_queue, and send to server.")
- self._send_data(data_obj)
- if resend_key not in self.resend_dict.keys():
- # Send data for the first time
- self.resend_dict[resend_key] = data_obj
- else:
- time.sleep(0.1)
-
- if self._ready_to_exit():
- logger.debug("Successfully close sending process.")
- break
- time.sleep(0.1)
-
- def _destroy_queue_data(self):
- while True:
- if self._ready_to_exit():
- break
-
- while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
- ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
- obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step)
- current_item = self.resend_dict.get(obj_key)
-
- if current_item is None:
- continue
-
- if ack_info == self.ACK_SUCCESS:
- self.resend_dict.pop(obj_key)
- elif ack_info == self.ACK_BUSY:
- logger.debug("RECV BUSY ACK")
- if current_item.busy_time > 5:
- self._resend_data(current_item)
- else:
- current_item.busy_time += 1
- elif ack_info == self.ACK_ERROR:
- logger.debug("RECV ERROR ACK")
- self._resend_data(current_item)
- elif ack_info == self.ACK_STOP_CONFIRM:
- logger.debug("RECV STOP ACK")
- self.factory.num_connections -= 1
-
- break
-
- time.sleep(0.1)
-
- def _resend_data(self, data: TCPDataItem):
- if data.retry_times < self.RESEND_RETRY_TIMES:
- data.retry_times += 1
- logger.debug(f"Resend data seq number: {data.sequence_number}")
- self.add_to_sending_queue(data)
- else:
- self.resend_dict.pop(data.sequence_number)
- logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")
-
- def _pending_data(self, data: TCPDataItem):
- if data.pending_time >= self.RESEND_PENDING_TIME:
- self.resend_dict.pop(data.sequence_number)
- logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
- return
-
- # wait time is 100MB per second
- pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
- data.pending_time += pending_time
- time.sleep(pending_time)
-
- def _ready_to_exit(self):
- return self.signal_exit or self.tcp_manager.signal_exit
-
-
-class ClientProtocol(protocol.Protocol):
- TIMEOUT = 60 * 10
-
- def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
- self.buffer = io.BytesIO()
- self.is_connected = False
- self.check_sum = check_sum
- self.tell = 0
- self.ack_queue = Queue(maxsize=ack_queue_size)
- self.file_sender = FileSender()
- self.file_sender.CHUNK_SIZE = chunk_size
- self.signal_exit = False
- self.defer = None
- self.kill_process = False
- self.ack = None
-
- self.timeout_call = None
-
- self.tls = tls
- self.send_buffer = b""
- self.buffer_cnt = 0
-
- def dataReceived(self, data):
- if self.timeout_call.active():
- self.timeout_call.reset(self.TIMEOUT)
-
- self.buffer.seek(0, 2)
- self.buffer.write(data)
- self.buffer.seek(self.tell)
- while True:
- if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3
- ack = self.buffer.read(5)
- self.ack = ack
- seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
- rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
- step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
- logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
- if ack == b"KILL_":
- self.kill_process = True
- logger.debug(f"接收到KILL信号, PID {os.getpid()}")
- if ack == b"OVER_":
- self.factory.num_connections -= 1
- self.tell += 29
- if not self.ack_queue.full():
- self.ack_queue.put((ack, seq_number, rank, step))
- self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
- self.tell = 0
- else:
- time.sleep(0.1)
- else:
- break
-
- def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
- length = len(data)
- data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else ""
- data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
- sequence_number.to_bytes(8, byteorder=bytes_order) + \
- rank.to_bytes(8, byteorder=bytes_order) + \
- step.to_bytes(8, byteorder=bytes_order) + \
- data_crc.encode() + \
- data
- logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")
-
- while True:
- if self.defer is None or self.defer.called:
- self.defer = self.send_large_data(data_meaasge)
- break
- time.sleep(0.01)
-
- def send_large_data(self, data):
-
- if self.tls:
- self.send_buffer += data
- self.buffer_cnt += 1
- if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE:
- d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
- self.send_buffer = b""
- self.buffer_cnt = 0
- else:
- d = None
- else:
- d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
- return d
-
- def connection_timeout(self):
- if self.factory.num_connections <= 0:
- return
-
- self.factory.num_connections -= 1
- logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
- self.transport.loseConnection()
-
- def connectionMade(self):
- self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
- self.is_connected = True
- self.factory.num_connections += 1
- logger.info("successfully connect server")
-
- def connectionLost(self, reason):
- self.signal_exit = True
- self.factory.num_connections -= 1
- logger.info(f"Lost connection with server, reason is : {reason.value}")
-
-
-class MessageClientFactory(protocol.ClientFactory):
- def __init__(self):
- self.num_connections = 0
-
- def clientConnectionFailed(self, connector, reason):
- logger.info(f"Fail to connection with server: {reason.getErrorMessage()}")
- reactor.stop()
-
- def clientConnectionLost(self, connector, reason):
- logger.info(f"Client lost connection with server: {reason.getErrorMessage()}")
- reactor.stop()
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
deleted file mode 100644
index 6fc36bcdecac81ae302ec9fd64079758f74e4071..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-from collections import namedtuple
-
-import pandas as pd
-import torch
-import torch.multiprocessing as mp
-
-from msprobe.core.common.const import Const, CompareConst
-from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
-from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
- binary_standard_api, absolute_standard_api
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
-from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
-
-# NPU vs GPU api list
-CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
-
-current_time = time.strftime("%Y%m%d%H%M%S")
-ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + "_rank*.csv"
-ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + "_rank*.csv"
-
-OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
- ['npu_data', 'gpu_data', 'rank', 'result_csv_path', 'details_csv_path'])
-# namedtuple of [instance of Comparator, func of run_touch_api_online, config of run_ut_config]
-CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
-
-
-def get_gpu_device():
- is_gpu = False
- try:
- import torch_npu
- except ImportError:
- is_gpu = True
- return is_gpu
-
-
-def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
- """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
- :param xpu_id: int
- :param consumer_queue: shared queues of ConsumerDispatcher
- :param common_config: namedtuple of CommonCompareConfig
- :param api_precision_csv_file: list, length is 2, result file name and details file name
- :return:
- """
- device_info = "cuda" if get_gpu_device() else "npu"
- logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
- gpu_device = torch.device(f'{device_info}:{xpu_id}')
-
- while True:
- if consumer_queue.empty():
- time.sleep(0.1)
- continue
-
- api_data = consumer_queue.get()
- if api_data == "KILL_":
- # current consumer finish
- return
-
- _, api_name, _ = api_data.name.split(Const.SEP)
- if api_name in CompareApi:
- # NPU vs GPU
- online_compare(api_data, gpu_device, common_config)
- else:
- # NPUvsCPU vs GPUvsCPU
- online_precision_compare(api_data, gpu_device, common_config, api_precision_csv_file)
-
-
-def online_precision_compare(api_data, device, common_config, api_precision_csv_file):
- """online run_ut for precision_compare: NPUvsCPU vs GPUvsCPU
- 1. get NPUvsCPU compare result
- 2. get GPUvsCPU compare result
- 3. call online_api_precision_compare
- :param api_data
- :param device
- :param common_config: namedtuple of CommonCompareConfig
- :param api_precision_csv_file: [result_file_name, details_file_name]
- """
- compare, func, config = common_config.compare, common_config.handle_func, common_config.config
- api_full_name = api_data.name
- [api_type, api_name, _] = api_full_name.split(Const.SEP)
- npu_args, npu_kwargs, npu_out = api_data.args, api_data.kwargs, api_data.result
-
- if npu_kwargs.get("device"):
- del npu_kwargs["device"]
-
- try:
- # NPU vs CPU
- cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
- cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
- cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
- cpu_out = exec_api(cpu_exec_params)
- npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
- npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
- npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
-
- # GPU vs CPU
- api_data_gpu = move2target_device(api_data, device) # args, kwargs -> gpu, result -> npu
- data_info = func(api_full_name, api_data_gpu, config.backward_content)
- gpu_out = data_info.bench_output
- gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank)
- gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True)
- gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1])
-
- # NPUvsCPU vs GPUvsCPU
- result_file_name, details_file_name = api_precision_csv_file
- precision_compare_config = OnlineApiPrecisionCompareConfig(npu_data, gpu_data, api_data.rank,
- result_file_name, details_file_name)
- online_api_precision_compare(precision_compare_config)
-
- except Exception as err:
- if "expected scalar type Long" in str(err):
- logger.warning(
- f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
- f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
- elif api_type in [Const.DISTRIBUTED]:
- logger.info(f"{api_full_name} is not supported for run ut. SKIP.")
- else:
- logger.error(f"Run {api_full_name} UT Error: {str(err)}")
-
- compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank))
-
- finally:
- torch.cuda.empty_cache()
-
-
-def online_compare(api_data, device, common_config):
- """online run_ut for compare:NPU vs GPU
- """
- compare, func, config = common_config.compare, common_config.handle_func, common_config.config
- api_full_name = api_data.name
- api_data = move2target_device(api_data, device)
- try:
- data_info = func(api_full_name, api_data, config.backward_content)
- is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
- logger.info(f"running api_full_name {api_full_name} ut, "
- f"is_fwd_success: {is_fwd_success}, "
- f"is_bwd_success: {is_bwd_success}")
- except Exception as err:
- [api_type, api_name, _] = api_full_name.split(Const.SEP)
- if "expected scalar type Long" in str(err):
- logger.warning(
- f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
- f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
- elif api_type in [Const.DISTRIBUTED]:
- logger.info(f"{api_full_name} is not supported for run ut. SKIP.")
- else:
- logger.error(f"Run {api_full_name} UT Error: {str(err)}")
-
- compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank))
-
- finally:
- torch.cuda.empty_cache()
-
-
-class ConsumerDispatcher:
- def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None:
- self.num_workers = num_workers
- self.capacity = capacity
- self.compare = compare
- self.queues = []
- self.processes = []
- self.reverse_sort = False
- self.pool = None
- self.device = device
- self.data_id = 0
- self.lock = mp.Lock()
- self.result_queue = mp.Queue()
- mp.set_start_method("spawn", force=True)
-
- def start(self, handle_func, config):
- self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
- api_precision_csv_file = [
- ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
- ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
- ]
- common_config = CommonCompareConfig(self.compare, handle_func, config)
- for xpu_id, q in enumerate(self.queues):
- p = mp.Process(name="run_ut_process", target=run_ut_process,
- args=(xpu_id, q, common_config, api_precision_csv_file))
-
- p.start()
- self.processes.append(p)
- logger.info(
- f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
- logger.info(
- f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
- logger.info("Successfully start unittest process.")
-
- def stop(self):
- for q in self.queues:
- while q.full():
- time.sleep(0.1)
- q.put("KILL_")
-
- for p in self.processes:
- p.join()
- logger.info("Successfully stop unittest process.")
- logger.info(f"Api_precision_compare task result is saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}")
- logger.info(f"Api_precision_compare task details is saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
-
- def update_consume_queue(self, api_data):
- while True:
- index = self._choose_max_empty_site_strategy()
- if index != -1:
- q = self.queues[index]
- q.put(api_data)
- break
- time.sleep(0.1)
-
- def _choose_max_empty_site_strategy(self):
- maximum = 0
- index = -1
- # 充分利用多卡资源,防止任务过多分配给前面的卡
- _reverse = 1 if not self.reverse_sort else -1
- for i, q in enumerate(self.queues[::_reverse]):
- empty_site = self.capacity - q.qsize()
- if empty_site > maximum:
- maximum = empty_site
- index = i
- index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index
- self.reverse_sort = not self.reverse_sort
- return index
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py
deleted file mode 100644
index 61650705e48056d5964c7ba48ff442247a08e4f9..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py
+++ /dev/null
@@ -1,115 +0,0 @@
-
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from collections import defaultdict
-from functools import wraps
-
-import torch
-from torch.utils._python_dispatch import TorchDispatchMode
-from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
-from msprobe.pytorch.common.utils import get_tensor_rank
-from msprobe.core.common.const import Const
-from msprobe.pytorch.common.log import logger
-from msprobe.core.common.file_utils import load_yaml
-
-
-def singleton(cls):
- _instance = {}
-
- @wraps(cls)
- def inner():
- if cls not in _instance:
- _instance[cls] = cls()
- return _instance[cls]
- return inner
-
-
-@singleton
-class Counter:
- def __init__(self) -> None:
- self.index_dict = defaultdict(int)
-
-
-counter = Counter()
-yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
-yaml_file = load_yaml(yaml_path)
-
-
-class AccuracyCheckerDispatch(TorchDispatchMode):
- def __init__(self, attl):
- super(AccuracyCheckerDispatch, self).__init__()
- self.attl = attl
- self.counter = counter
- self.aten_ops_blacklist = []
- self.npu_adjust_autogard = []
- self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', [])
- self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', [])
-
- def __torch_dispatch__(self, func, types, args=None, kwargs=None):
- func_name_split_list = func.__name__.split(Const.SEP)
- aten_api = func_name_split_list[0]
- self.enable_autogard(aten_api)
- if aten_api in self.aten_ops_blacklist:
- npu_out = func(*args, **kwargs)
- return npu_out
-
- res = func(*args, **kwargs)
- cur_rank = get_tensor_rank(args, res)
- cur_api_number = self.counter.index_dict[aten_api]
- api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
- logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
- api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
- if "device" in api_data.kwargs:
- api_data.kwargs.pop("device")
- if self.attl.nfs_path:
- self.attl.upload(api_data)
- else:
- self.attl.send(api_data)
- self.counter.index_dict[aten_api] += 1
-
- return res
-
- def enable_autogard(self, aten_api):
- if aten_api in self.npu_adjust_autogard:
- torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
-
-
-def dispatch4data(func, attl, status):
- @wraps(func)
- def wrapper(*args, **kwargs):
- if not status:
- return func(*args, **kwargs)
- with AccuracyCheckerDispatch(attl):
- res = func(*args, **kwargs)
- return res
-
- return wrapper
-
-
-def run_ut_dispatch(attl, status, is_recompute=False):
- """
- This function called by online_run_ut.
- It is used to enable or disable dispatch for torch.autograd.backward function.
-
- Args:
- attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
- status (bool): True means enable dispatch, False means disable dispatch.
- is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
- """
- if is_recompute:
- return
- torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py
deleted file mode 100644
index d51138941c7711e404d561e0c92389de581c3b3c..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from functools import partial
-import os
-import struct
-import zlib
-import time
-import io
-from threading import Thread
-
-from twisted.internet import reactor, protocol, endpoints, ssl
-
-from msprobe.pytorch.common.utils import logger
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
- STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
-
-
-class TCPServer:
- def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None:
- self.port = port
- self.shared_queue = shared_queue
- self.check_sum = check_sum
- self.tls_path = tls_path
- self.factory = MessageServerFactory()
- self.reactor_thread = None
-
- @staticmethod
- def run_reactor():
- reactor.run(installSignalHandlers=False)
-
- def start(self):
- self.factory.protocol = self.build_protocol
-
- if self.tls_path:
- server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
- key_file=os.path.join(self.tls_path, "server.key"),
- cert_file=os.path.join(self.tls_path, "server.crt"),
- ca_file=os.path.join(self.tls_path, "ca.crt"),
- crl_file=os.path.join(self.tls_path, "crl.pem")
- )
-
- ssl_options = ssl.CertificateOptions(
- privateKey=server_key,
- certificate=server_crt,
- method=ssl.SSL.TLSv1_2_METHOD,
- verify=True,
- requireCertificate=True,
- caCerts=[ca_crt], # 信任的CA证书列表
- )
- ssl_context = ssl_options.getContext()
- ssl_context.set_cipher_list(cipher_list)
- ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
- ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
- partial(verify_callback, crl=crl_pem))
-
- endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
- else:
- endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
- endpoint.listen(self.factory)
- self.reactor_thread = Thread(target=self.run_reactor, daemon=True)
- self.reactor_thread.start()
-
- def is_running(self):
- return not self.factory.is_all_connection_closed()
-
- def stop(self):
- self.factory.doStop()
- reactor.callFromThread(reactor.sigInt, 2)
- self.reactor_thread.join()
-
- def build_protocol(self):
- return ServerProtocol(self.shared_queue, self.check_sum)
-
-
-class ServerProtocol(protocol.Protocol):
- ACK_SUCCESS = b"OK___"
- ACK_ERROR = b"ERROR"
- ACK_BUSY = b"BUSY_"
- ACK_STOP = b"STOP_"
- ACK_STOP_CONFIRM = b"OVER_"
- ACK_KILL_PROCESS = b"KILL_"
-
- def __init__(self, shared_queue, check_sum=False):
- self.start_time = None
- self.buffer = io.BytesIO()
- self.consumer_queue = shared_queue
- self.check_sum = check_sum
- self.length_width = 8
- self.crc_width = 8
- self.obj_length = None
- self.tell = 0
- self.obj_crc = None
- self.obj_body = None
- self.sequence_number = -1
- self.rank = -1
- self.step = -1
- self.sequence_number_dict = dict()
-
- def connectionMade(self):
- self.buffer = io.BytesIO()
- self.obj_length = None
- self.tell = 0
- self.obj_crc = None
- self.obj_body = None
- self.factory.transport_dict[self.transport] = 1
- self.factory.transport_list.append(self.transport)
- logger.info(f"Connected to {self.transport.getPeer()} successfully.")
-
- def connectionLost(self, reason):
- self.factory.transport_dict.pop(self.transport, None)
- if len(self.factory.transport_dict) == 0:
- self.consumer_queue.put(self.ACK_KILL_PROCESS)
-
- logger.info(f"Lost connection with {self.transport.getPeer()}. Reason is: {reason} 与客户端 断开连接, "
- f"current connection number is: {len(self.factory.transport_dict)}")
-
- def send_ack(self, ack_info):
- ack_message = b"".join([
- ack_info,
- self.sequence_number.to_bytes(8, byteorder=bytes_order),
- self.rank.to_bytes(8, byteorder=bytes_order),
- self.step.to_bytes(8, byteorder=bytes_order)
- ])
- self.transport.write(ack_message)
-
- def post_process(self):
- send_busy_ack = False
- while self.consumer_queue.full():
- if not send_busy_ack:
- self.send_ack(self.ACK_BUSY)
- logger.debug("sending BUSY ACK")
- send_busy_ack = True
- time.sleep(0.1)
-
- obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
- # get the crc value of a 16-bit string with a length of 8
- recv_crc = f"{zlib.crc32(self.obj_body):08x}"
-
- if self.check_sum and recv_crc != self.obj_crc:
- # when needs check hash value and check no pass, indicates received data error, send b"ERROR" to client.
- logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
- self.send_ack(self.ACK_ERROR)
- else:
- if self.obj_body == self.ACK_STOP:
- self.handle_with_stop()
- else:
- self.send_ack(self.ACK_SUCCESS)
- if obj_key in self.sequence_number_dict:
- logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
- else:
- self.sequence_number_dict[obj_key] = self.obj_crc
- self.consumer_queue.put(self.obj_body, block=True)
-
- self.reset_env()
- finish_time = time.time()
- logger.debug(f"finish_time: {finish_time - self.start_time}")
-
- def handle_with_stop(self):
- logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}")
- self.send_ack(self.ACK_STOP_CONFIRM)
- if len(self.factory.transport_dict) == 0:
- _rank, _step, _sequence_number = 0, 0, 100000000
- ack_kill = self.ACK_KILL_PROCESS + \
- _sequence_number.to_bytes(8, byteorder='big') + \
- _rank.to_bytes(8, byteorder='big') + \
- _step.to_bytes(8, byteorder='big')
- for trans in self.factory.transport_list:
- trans.write(ack_kill)
- logger.debug(f"发送KILL信息给{self.transport.getPeer()}")
- self.consumer_queue.put(self.ACK_KILL_PROCESS)
- time.sleep(2)
-
- def reset_env(self):
- self.obj_length = None
- self.sequence_number = -1
- self.rank = -1
- self.step = -1
- self.obj_crc = None
- self.obj_body = None
-
- def dataReceived(self, data):
- self.buffer.seek(0, 2)
- self.buffer.write(data)
- self.buffer.seek(self.tell)
-
- # The first data packet is packet header, it contains obj_length, sequence_number, rank, step
- if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4:
- self.start_time = time.time()
- self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
- self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
- self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
- self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
- self.tell += self.length_width * 4
- logger.debug(
- f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
-
- # If needs check hash but not parse crc yet, read 8b crc values
- check_sum_and_crc = (self.check_sum
- and self.obj_length is not None
- and self.obj_crc is None
- and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
- if check_sum_and_crc:
- self.obj_crc = self.buffer.read(self.crc_width).decode()
- self.tell += self.crc_width
- logger.debug(f"Hash value: {self.obj_crc}")
-
- current_length = len(self.buffer.getvalue()) - self.tell
- if self.obj_length is not None and 0 < self.obj_length <= current_length:
- # Current api data receive finished
- self.obj_body = self.buffer.read(self.obj_length)
-
- self.tell += self.obj_length
- self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
- self.buffer.seek(0)
- self.tell = 0
- recv_data_time = time.time()
- logger.debug(f"self.sequence_number {self.sequence_number} "
- f"recv_data_time {recv_data_time - self.start_time}")
-
- if self.obj_body == self.ACK_STOP:
- # Indicates the current TCP link receives a STOP signal and remove from the transport_dict
- _transport = self.factory.transport_dict.pop(self.transport, None)
- logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ")
- self.post_process()
-
-
-class MessageServerFactory(protocol.ServerFactory):
- def __init__(self) -> None:
- """
- transport_dict: links that have not completed data transmission.
- transport_list: Records all TCP links. Appends TCP link to the transport list
- when a new TCP link is established.
- """
- self.transport_dict = {}
- self.transport_list = []
-
- def is_all_connection_closed(self):
- return len(self.transport_dict) == 0
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml
deleted file mode 100644
index 373e6ed0fc33a97537f38eedeb36a3e90122525a..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml
+++ /dev/null
@@ -1,63 +0,0 @@
-aten_ops_blacklist:
- - npu_binary_cross_entropy_with_logits_backward
- - npu_ciou_backward
- - _cudnn_rnn
- - _local_scalar_dense
- - _pin_memory
- - _to_copy
- - _unsafe_view
- - clone
- - contiguous
- - copy_
- - cudnn_batch_norm
- - cudnn_batch_norm_backward
- - detach
- - empty
- - index_put_
- - lift_fresh
- - max_pool2d_with_indices_backward # shape unmatch
- - native_batch_norm_backward
- - new_empty
- - new_empty_strided
- - new_full
- - new_ones
- - new_zeros
- - ones
- - ones_like
- - permute
- - rand
- - rand_like
- - randint
- - randint_like
- - randn
- - randn_like
- - randperm
- - scalar_tensor
- - select
- - to
- - transpose
- - unbind
- - view
- - zero
- - zero_
- - zeros
- - zeros_like
- - _record_function_enter_new
- - _record_function_exit
- - broadcast_
- - allreduce_
- - npu_clear_float_status
- - npu_format_cast
- - npu_dtype_cast
- - npu_dtype_cast_backward
- - _allgather_base_
- - _reduce_scatter_base_
- - is_same_size
-
-npu_adjust_autogard:
- - adaptive_avg_pool2d
- - batch_norm
- - log_softmax
- - nll_loss
- - to
-
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py
deleted file mode 100644
index 05dd50a3f2bbc6637926c45f7c96f7d90e01edbf..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py
+++ /dev/null
@@ -1,198 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import gc
-import os
-from datetime import datetime, timezone
-
-from OpenSSL import crypto
-from cryptography import x509
-from cryptography.hazmat.backends import default_backend
-from dateutil import parser
-
-from msprobe.core.common.file_utils import FileOpen
-from msprobe.core.common.log import logger
-
-cipher_list = ":".join(
- ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
- "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
- "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256",
- "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384",
- "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256",
- "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384",
- "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
- "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
- "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
- "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
- "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
- "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
- "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
- "TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256",
- "TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384",
- "TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256",
- "TLS_DHE_RSA_WITH_AES_128_CCM",
- "TLS_DHE_RSA_WITH_AES_256_CCM",
- "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
- "TLS_DHE_PSK_WITH_AES_128_CCM",
- "TLS_DHE_PSK_WITH_AES_256_CCM",
- "TLS_ECDHE_ECDSA_WITH_AES_128_CCM",
- "TLS_ECDHE_ECDSA_WITH_AES_256_CCM",
- "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
-).encode()
-
-STRUCT_UNPACK_MODE = "!Q"
-STR_TO_BYTES_ORDER = "big"
-
-
-def is_certificate_revoked(cert, crl):
- # 获取证书的序列号
- cert_serial_number = cert.get_serial_number()
-
- # 检查证书是否在CRL中
- revoked_serials = [revoked_cert.serial_number for revoked_cert in crl]
- if cert_serial_number in revoked_serials:
- logger.error(f"证书已吊销:{cert_serial_number:020x}")
- return True
-
- return False
-
-
-def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
- """
- 验证对端证书的有效性
- :param conn: OpenSSL.SSL.Connection, SSL 连接对象
- :param cert: OpenSSL.crypto.X509, 当前证书
- :param errno: int, OpenSSL错误代码, 0:无错误 | 9:证书过期 | 18: 自签名证书
- :param depth: int, 当前证书在证书链中的深度 (0=叶子节点), 1:中间CA证书 -1:根CA证书 2+:更高级别CA证书
- :param preverify_ok: int, 验证结果 (1=通过, 0=失败)
- :param crl: _CRLInternal, CRL证书对象
- :return: bool, True表示接受证书, False表示拒绝
- """
-
- if not preverify_ok:
- from OpenSSL import SSL
- error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
- logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
- return False
-
- if crl and is_certificate_revoked(cert, crl):
- return False
-
- return preverify_ok
-
-
-def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
- """
- Load SSL PEM files.
-
- Args:
- key_file (str): The path to the private key file.
- cert_file (str): The path to the certificate file.
- ca_file (str): The path to the CA certificate file.
- crl_file (str): The path to the CRL file.
-
- Returns:
- tuple: (key, crt, ca_crt, crl)
-
- Raises:
- Exception: If the file paths are invalid or the file contents are incorrect, exceptions may be thrown.
- """
-
- try:
- # your_private_key_password
- import pwinput
- passphrase = pwinput.pwinput("Enter your password: ")
- with FileOpen(key_file, "rb") as f:
- key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
- del passphrase
- gc.collect()
- with FileOpen(cert_file, "rb") as f:
- crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
- check_crt_valid(crt)
-
- crt_serial_number = hex(crt.get_serial_number())[2:]
- logger.info(f"crt_serial_number: {crt_serial_number}")
-
- check_certificate_match(crt, key)
-
- with FileOpen(ca_file, "rb") as f:
- ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
- check_crt_valid(ca_crt)
-
- ca_serial_number = hex(ca_crt.get_serial_number())[2:]
- logger.info(f"ca_serial_number: {ca_serial_number}")
- crl = None
- if os.path.exists(crl_file):
- with FileOpen(crl_file, "rb") as f:
- crl = x509.load_pem_x509_crl(f.read(), default_backend())
- check_crl_valid(crl, ca_crt)
- for revoked_cert in crl:
- logger.info(f"Serial Number: {revoked_cert.serial_number}, "
- f"Revocation Date: {revoked_cert.revocation_date_utc}")
-
- except Exception as e:
- raise RuntimeError(f"The SSL certificate is invalid") from e
-
- return key, crt, ca_crt, crl
-
-
-def check_crt_valid(pem):
- """
- Check the validity of the SSL certificate.
-
- Raises:
- RuntimeError: If the SSL certificate is invalid or expired.
- """
- try:
- pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
- pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
- logger.info(f"The SSL certificate passes the verification and the validity period "
- f"starts from {pem_start} ends at {pem_end}.")
- except Exception as e:
- raise RuntimeError(f"The SSL certificate is invalid") from e
-
- now_utc = datetime.now(tz=timezone.utc)
- if pem.has_expired() or not (pem_start <= now_utc <= pem_end):
- raise RuntimeError(f"The SSL certificate has expired.")
-
-
-def check_certificate_match(certificate, private_key):
- """
- Check certificate and private_key is match or not. if mismatched, an exception is thrown.
- :param certificate:
- :param private_key:
- :return:
- """
- test_data = os.urandom(256)
- try:
- signature = crypto.sign(private_key, test_data, "sha256")
- crypto.verify(
- certificate, # 包含公钥的证书
- signature, # 生成的签名
- test_data, # 原始数据
- "sha256", # 哈希算法
- )
- logger.info("公钥和私钥匹配")
- except Exception as e:
- raise RuntimeError("公钥和私钥不匹配") from e
-
-
-def check_crl_valid(crl, ca_crt):
- # 验证CRL签名(确保CRL未被篡改)
- if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
- raise RuntimeError("CRL签名无效!")
-
- # 检查CRL有效期
- if not (crl.last_update <= datetime.utcnow() <= crl.next_update):
- raise RuntimeError("CRL已过期或尚未生效!")
diff --git a/debug/accuracy_tools/msprobe/pytorch/attl_manager.py b/debug/accuracy_tools/msprobe/pytorch/attl_manager.py
deleted file mode 100644
index d24e85029bb462861527f8698715a390eaf64758..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/attl_manager.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) 2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from msprobe.core.common.runtime import Runtime
-from msprobe.core.common.utils import Const
-from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
-from msprobe.pytorch.common.log import logger
-
-
-class ATTLManager:
- def __init__(self, config):
- self.config = config
- self.attl = None
-
- def attl_init(self):
- if self.config.online_run_ut:
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
- attl_config = ATTLConfig(is_benchmark_device=False,
- connect_ip=self.config.host,
- connect_port=self.config.port,
- nfs_path=self.config.nfs_path,
- tls_path=self.config.tls_path)
- need_dump = len(self.config.rank) == 0 or Runtime.current_rank in self.config.rank
- self.attl = ATTL('npu', attl_config, need_dump=need_dump)
- if self.config.nfs_path:
- self.attl.upload("start")
-
- def attl_send(self, name, args, kwargs, output):
- api_data = ApiData(
- name[:-len(Const.FORWARD_NAME_SUFFIX)],
- args,
- kwargs,
- output,
- Runtime.current_iter,
- Runtime.current_rank
- )
- logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")
- api_type, _, _ = api_data.name.split(Const.SEP)
- if api_type in [Const.DISTRIBUTED]:
- logger.info(f"api {api_data.name} is not supported, skip")
- return
- if self.config.nfs_path:
- self.attl.upload(api_data)
- else:
- self.attl.send(api_data)
-
- def attl_stop(self):
- if self.config.nfs_path:
- self.attl.upload("end")
- elif self.attl.socket_manager is not None:
- logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
- self.attl.socket_manager.send_stop_signal()
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
index 2f43f240112a5f6d7024702cd9a474caa49cff0b..004b8484878e3dc8d15eb7e4e9b4abec498417a6 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
@@ -48,16 +48,6 @@ class DebuggerConfig:
"max_sample": task_config.max_sample
}
- self.online_run_ut = False
- if self.task == Const.TENSOR:
- # dump api tensor and collaborate with online run_ut
- self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
- self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
- self.tls_path = task_config.tls_path if task_config.tls_path else ""
- self.host = task_config.host if task_config.host else ""
- self.port = task_config.port if task_config.port else -1
- self.online_run_ut_recompute = task_config.online_run_ut_recompute \
- if isinstance(task_config.online_run_ut_recompute, bool) else False
self.check()
self._check_statistics_config(task_config)
diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
index 62c0ce84879028eefc055138f9478035bf240b13..4e2770f0258b52366dcc1019feae71aacfda29fc 100644
--- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
class TensorConfig(BaseConfig):
def __init__(self, json_config):
super().__init__(json_config)
- self.online_run_ut = json_config.get("online_run_ut", False)
- self.nfs_path = json_config.get("nfs_path", "")
- self.host = json_config.get("host", "")
- self.port = json_config.get("port", -1)
- self.tls_path = json_config.get("tls_path", "./")
- self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
self.check_config()
self._check_summary_mode()
self._check_file_format()
- if self.online_run_ut:
- self._check_online_run_ut()
+
def _check_file_format(self):
if self.file_format is not None and self.file_format not in ["npy", "bin"]:
raise Exception("file_format is invalid")
- def _check_online_run_ut(self):
- if not isinstance(self.online_run_ut, bool):
- raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
-
- if not isinstance(self.online_run_ut_recompute, bool):
- raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
-
- if self.nfs_path:
- check_file_or_directory_path(self.nfs_path, isdir=True)
- return
-
- if self.tls_path:
- check_file_or_directory_path(self.tls_path, isdir=True)
- check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
- check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
- check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
- crl_path = os.path.join(self.tls_path, "crl.pem")
- if os.path.exists(crl_path):
- check_file_or_directory_path(crl_path)
-
- if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
- raise Exception(f"host: {self.host} is invalid.")
-
- if not isinstance(self.port, int) or not (0 < self.port <= 65535):
- raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
-
class StatisticsConfig(BaseConfig):
def __init__(self, json_config):
@@ -251,12 +218,7 @@ class RunUTConfig(BaseConfig):
self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
- self.is_online = json_config.get("is_online", False)
- self.nfs_path = json_config.get("nfs_path", "")
- self.host = json_config.get("host", "")
- self.port = json_config.get("port", -1)
- self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
- self.tls_path = json_config.get("tls_path", "./")
+
self.check_run_ut_config()
@classmethod
@@ -274,22 +236,11 @@ class RunUTConfig(BaseConfig):
if not os.path.exists(error_data_path):
raise Exception("error_data_path: %s does not exist" % error_data_path)
- @classmethod
- def check_nfs_path_config(cls, nfs_path):
- if nfs_path:
- FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
-
- @classmethod
- def check_tls_path_config(cls, tls_path):
- if tls_path:
- FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
def check_run_ut_config(self):
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
RunUTConfig.check_error_data_path_config(self.error_data_path)
- RunUTConfig.check_nfs_path_config(self.nfs_path)
- RunUTConfig.check_tls_path_config(self.tls_path)
class GradToolConfig(BaseConfig):
diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py
index e2181746bd04d648901fe1a063dc77dc02c70da9..4007553b6d80bade6f3f13456dab1cd95b4c5734 100644
--- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py
@@ -15,7 +15,6 @@
from msprobe.core.common.utils import Const
from msprobe.core.service import BaseService
-from msprobe.pytorch.attl_manager import ATTLManager
from msprobe.pytorch.common.log import logger
from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
@@ -25,9 +24,6 @@ from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, wrap_ji
from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
-if torch_version_above_or_equal_2:
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
-
class PytorchService(BaseService):
@property
@@ -45,12 +41,10 @@ class PytorchService(BaseService):
self.logger = logger
self.api_register = get_api_register()
self.module_processor = ModuleProcesser(self.data_collector.scope)
- self.attl_manager = ATTLManager(self.config)
- self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
+ self.hook_manager = PytorchHookManager(self.data_collector, self.config)
self.api_template = ApiTemplate
def _register_hook(self):
- self.attl_manager.attl_init()
if self._is_mix_level:
register_optimizer_hook(self.data_collector)
@@ -64,9 +58,6 @@ class PytorchService(BaseService):
self.module_processor.register_module_hook(self.model, self.build_hook)
self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
- def _run_ut_dispatch(self, status):
- if torch_version_above_or_equal_2:
- run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
def _reset_status(self):
super()._reset_status()
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py
index acab05db44dc02c3626fb69569ed1385f8ceaf17..593428e6735f5ce88c9652186a2609f66610bb0c 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py
@@ -55,11 +55,9 @@ class TestBaseHookManager(unittest.TestCase):
self.mock_data_collector = MagicMock()
self.mock_config = MagicMock()
self.mock_config.data_mode = ["all"]
- self.mock_attl_manager = MagicMock()
self.manager = self.MockBaseHookManager(
self.mock_data_collector,
- self.mock_config,
- self.mock_attl_manager
+ self.mock_config
)
BaseHookManager.inner_switch[threading.get_ident()] = False
BaseHookManager.hook_handle_dict = {}
@@ -68,7 +66,6 @@ class TestBaseHookManager(unittest.TestCase):
def test_init(self):
self.assertEqual(self.manager.data_collector, self.mock_data_collector)
self.assertEqual(self.manager.config, self.mock_config)
- self.assertEqual(self.manager.attl_manager, self.mock_attl_manager)
def test_should_execute_hook_conditions(self):
module = MagicMock()
@@ -130,7 +127,6 @@ class TestBaseHookManager(unittest.TestCase):
def test_forward_pre_hook_behavior(self, mock_should_execute_hook, mock_release):
mock_should_execute_hook.return_value = True
mock_release.return_value = None
- self.manager.config.online_run_ut = None
hook = self.manager._build_forward_pre_hook(Const.API, "api_name", "func_name")
module = MagicMock()
module.msprobe_input_kwargs = {"kwarg": "value"}
@@ -154,11 +150,6 @@ class TestBaseHookManager(unittest.TestCase):
kwargs = {"kwargs": []}
output = MagicMock()
- self.manager.config.online_run_ut = True
- hook(module, args, output)
- self.mock_attl_manager.attl_send.assert_called_once()
-
- self.manager.config.online_run_ut = None
self.mock_data_collector.if_return_forward_new_output.return_value = False
with patch.object(self.manager, '_get_params_dict', return_value={}):
result = hook(module, args, kwargs, output)
@@ -175,7 +166,6 @@ class TestBaseHookManager(unittest.TestCase):
@patch.object(BaseHookManager, "_should_execute_hook")
def test_backward_hook_behavior(self, mock_should_execute_hook):
mock_should_execute_hook.return_value = True
- self.manager.config.online_run_ut = None
hook = self.manager._build_backward_hook(Const.API, "api_name")
module = MagicMock()
grad_input = (MagicMock(),)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_service.py b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py
index e440cf14f236f4a87327c6fa549da66984eb7b06..5a241790055a8a68b1a0c8c7c7eb21c6bacd4e1d 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/test_service.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py
@@ -61,7 +61,6 @@ class TestBaseService(unittest.TestCase):
self.config.task = Const.STATISTICS
self.config.async_dump = True
self.config.tensor_list = []
- self.config.online_run_ut = False
self.config.framework = "test_framwork"
with patch('msprobe.core.service.build_data_collector'):
@@ -314,12 +313,9 @@ class TestBaseService(unittest.TestCase):
def test_need_stop_service_conditions(self):
self.service.current_iter = 4
self.service.config.step = [1, 2, 3]
- self.service.config.online_run_ut = True
- self.service.attl_manager = MagicMock()
self.assertTrue(self.service._need_stop_service())
self.assertFalse(Runtime.is_running)
self.assertFalse(self.service.primitive_switch)
- self.service.attl_manager.attl_stop.assert_called()
self.service.current_iter = 1
self.service.data_collector.data_processor.is_terminated = True
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
index 1d777a52752b9d8273f8e3604269af0383677e18..7777ab41878829a290b396565a90b8001ab86977 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
@@ -131,14 +131,12 @@ class TestMindsporeService(unittest.TestCase):
@patch('msprobe.mindspore.mindspore_service.JitDump')
def test_start_jit_enabled(self, mock_jit_dump):
self.service.data_collector.data_processor.is_terminated = False
- self.service.config.online_run_ut = None
model_mock = MagicMock()
self.service.start(model=model_mock)
self.assertTrue(mock_jit_dump.jit_dump_switch)
@patch('msprobe.mindspore.mindspore_service.JitDump')
def test_stop_jit_disabled(self, mock_jit_dump):
- self.service.config.online_run_ut = None
self.config.level = Const.LEVEL_MIX
self.service.current_iter = 1
self.service.current_rank = 0
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
index df03485dc6c77371750fd0b67ca2c37ff7e2ed7b..7c82585324effcac9a08dd1d2d5827c894311775 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
@@ -2,7 +2,7 @@ import unittest
import os
from unittest.mock import patch
-from msprobe.pytorch.api_accuracy_checker.common.config import Config, CheckerConfig, OnlineConfig, msCheckerConfig
+from msprobe.pytorch.api_accuracy_checker.common.config import Config, CheckerConfig, msCheckerConfig
class TestUtConfig():
@@ -10,12 +10,6 @@ class TestUtConfig():
self.white_list = ['api1', 'api2']
self.black_list = ['api3']
self.error_data_path = '/path/to/error_data'
- self.is_online = True
- self.nfs_path = '/path/to/nfs'
- self.host = 'localhost'
- self.port = 8080
- self.rank_list = [0, 1, 2]
- self.tls_path = '/path/to/tls'
class TestConfig(unittest.TestCase):
@@ -60,46 +54,19 @@ class TestConfig(unittest.TestCase):
self.assertEqual(checker_config.white_list, msCheckerConfig.white_list)
self.assertEqual(checker_config.black_list, msCheckerConfig.black_list)
self.assertEqual(checker_config.error_data_path, msCheckerConfig.error_data_path)
- self.assertEqual(checker_config.is_online, msCheckerConfig.is_online)
- self.assertEqual(checker_config.nfs_path, msCheckerConfig.nfs_path)
- self.assertEqual(checker_config.host, msCheckerConfig.host)
- self.assertEqual(checker_config.port, msCheckerConfig.port)
- self.assertEqual(checker_config.rank_list, msCheckerConfig.rank_list)
- self.assertEqual(checker_config.tls_path, msCheckerConfig.tls_path)
+
def test_init_with_task_config(self):
checker_config = CheckerConfig(self.task_config)
self.assertEqual(checker_config.white_list, self.task_config.white_list)
self.assertEqual(checker_config.black_list, self.task_config.black_list)
self.assertEqual(checker_config.error_data_path, self.task_config.error_data_path)
- self.assertEqual(checker_config.is_online, self.task_config.is_online)
- self.assertEqual(checker_config.nfs_path, self.task_config.nfs_path)
- self.assertEqual(checker_config.host, self.task_config.host)
- self.assertEqual(checker_config.port, self.task_config.port)
- self.assertEqual(checker_config.rank_list, self.task_config.rank_list)
- self.assertEqual(checker_config.tls_path, self.task_config.tls_path)
+
def test_load_config(self):
checker_config = CheckerConfig()
checker_config.load_config(self.task_config)
- self.assertEqual(checker_config.is_online, self.task_config.is_online)
- self.assertEqual(checker_config.nfs_path, self.task_config.nfs_path)
- self.assertEqual(checker_config.host, self.task_config.host)
- self.assertEqual(checker_config.port, self.task_config.port)
- self.assertEqual(checker_config.rank_list, self.task_config.rank_list)
- self.assertEqual(checker_config.tls_path, self.task_config.tls_path)
-
- def test_get_online_config(self):
- checker_config = CheckerConfig()
- checker_config.load_config(self.task_config)
- online_config = checker_config.get_online_config()
- self.assertIsInstance(online_config, OnlineConfig)
- self.assertEqual(online_config.is_online, self.task_config.is_online)
- self.assertEqual(online_config.nfs_path, self.task_config.nfs_path)
- self.assertEqual(online_config.host, self.task_config.host)
- self.assertEqual(online_config.port, self.task_config.port)
- self.assertEqual(online_config.rank_list, self.task_config.rank_list)
- self.assertEqual(online_config.tls_path, self.task_config.tls_path)
+
def test_get_run_ut_config(self):
forward_content = {'api1': 'data1', 'api2': 'data2'}
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
index 15a7908ad8de6d4883e0574ceaf451a03dbfbfe3..9e14e035ab6550250c3b992c653fefcb1f9dc8d1 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
@@ -350,59 +350,7 @@ class TestApiPrecisionCompare(unittest.TestCase):
if os.path.exists(base_path):
os.rmdir(base_path)
- def test_online_api_precision_compare(self):
- # 准备测试目录和文件
- base_path = 'test_online_compare_tmp'
- os.makedirs(base_path, exist_ok=True)
-
- # 创建测试用的CSV文件
- npu_csv = os.path.join(base_path, 'npu.csv')
- gpu_csv = os.path.join(base_path, 'gpu.csv')
- result_csv = os.path.join(base_path, 'results_rank1.csv')
- details_csv = os.path.join(base_path, 'details_rank1.csv')
-
- # 准备在线比较的配置
- online_config = MagicMock()
- online_config.rank = 1
- online_config.result_csv_path = os.path.join(base_path, "results_rank*.csv")
- online_config.details_csv_path = os.path.join(base_path, "details_rank*.csv")
-
- # 将测试数据写入CSV文件
- df = pd.DataFrame(self.test_data)
- df.to_csv(npu_csv, index=False)
- df.to_csv(gpu_csv, index=False)
-
- # 设置online_config的数据
- online_config.npu_data = pd.read_csv(npu_csv)
- online_config.gpu_data = pd.read_csv(gpu_csv)
-
- try:
- # 执行在线比较
- online_api_precision_compare(online_config)
-
- # 验证结果文件是否生成
- self.assertTrue(os.path.exists(result_csv))
- self.assertTrue(os.path.exists(details_csv))
-
- # 读取并验证结果
- result_df = pd.read_csv(result_csv)
- self.assertFalse(result_df.empty)
-
- details_df = pd.read_csv(details_csv)
- self.assertFalse(details_df.empty)
-
- # 验证文件权限
- self.assertEqual(os.stat(result_csv).st_mode & 0o777, FileCheckConst.DATA_FILE_AUTHORITY)
- self.assertEqual(os.stat(details_csv).st_mode & 0o777, FileCheckConst.DATA_FILE_AUTHORITY)
-
- finally:
- # 清理测试文件
- for file_path in [npu_csv, gpu_csv, result_csv, details_csv]:
- if os.path.exists(file_path):
- os.remove(file_path)
- if os.path.exists(base_path):
- os.rmdir(base_path)
-
+
def test_skip_due_to_empty_output(self):
self.row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] = ' '
api_name = "abs"
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
index 13bf0a5b19c49576101c8f4daf0d609ee625aefe..bd07015d37ab9c1bac067db73fffe8287062526a 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py
@@ -284,77 +284,5 @@ class TestRunUtMethods(unittest.TestCase):
self.temp_dir.cleanup()
-class TestRunUtOnlineConfig(unittest.TestCase):
-
- def test_checked_online_config(self):
- class OnlineConfigClass:
- is_online = True
- rank_list = [0, 1]
- nfs_path = ""
- tls_path = ""
- host = "127.0.0.1"
- port = 12345
-
- online_config = OnlineConfigClass()
- res = checked_online_config(online_config)
- self.assertIsNone(res)
-
- # test is_online
- online_config.is_online = "True"
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), f"is_online must be bool type")
- online_config.is_online = True
-
- # test rank_list
- online_config.rank_list = "1234"
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), f"rank_list must be a list")
- online_config.rank_list = ["1", "2"]
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), f"All elements in rank_list must be integers")
- online_config.rank_list = [1, 2]
-
- # test nfs_path
- online_config.nfs_path = "./nfs_path"
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ")
- online_config.nfs_path = ""
-
- # test tls_path
- online_config.tls_path = "./tls_path"
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ")
-
- os.makedirs(online_config.tls_path)
- with open(os.path.join(online_config.tls_path, "server.key"), 'w') as file:
- file.write("1")
- with open(os.path.join(online_config.tls_path, "server.crt"), 'w') as file:
- file.write("1")
- with open(os.path.join(online_config.tls_path, "ca.crt"), 'w') as file:
- file.write("1")
- checked_online_config(online_config)
- shutil.rmtree(online_config.tls_path)
- online_config.tls_path = ""
-
- # test host
- online_config.host = "invalid_host"
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), f"host: {online_config.host} is invalid.")
- online_config.host = "127.0.0.1"
-
- # test port
- online_config.port = -1
- with self.assertRaises(Exception) as context:
- checked_online_config(online_config)
- self.assertIn(str(context.exception), f"port: {online_config.port} is invalid, port range 0-65535.")
- online_config.port = 6123
-
-
if __name__ == '__main__':
unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py
deleted file mode 100644
index 7d4e6e950dc1d3e51ef69ca46895fcf5078c5f67..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_attl.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# coding=utf-8
-import unittest
-from unittest.mock import patch
-from multiprocessing import Queue
-
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import *
-from msprobe.core.common.file_utils import create_directory
-
-class TestATTL(unittest.TestCase):
-
- def setUp(self):
- nfs_path = "temp_nfs_path"
- create_directory(nfs_path)
- self.nfs_path = os.path.realpath(nfs_path)
- self.session_id = "test_session"
- self.session_config = ATTLConfig(is_benchmark_device=False, connect_ip='127.0.0.1',
- connect_port=8080, nfs_path=self.nfs_path , check_sum=False, queue_size=100)
- self.attls = ATTL(self.session_id, self.session_config, need_dump=False)
- self.buffer = ApiData('test_api', args=(torch.randn(2, 2),), kwargs={'device': 'cpu'},
- result=torch.randn(2, 2), step=1, rank=1)
-
- def tearDown(self):
- for filename in os.listdir(self.nfs_path):
- os.remove(os.path.join(self.nfs_path, filename))
- os.rmdir(self.nfs_path)
-
- def test_attl_config(self):
- config = ATTLConfig(is_benchmark_device=True, connect_ip='192.168.1.1', connect_port=9090,
- nfs_path=self.nfs_path, tls_path='/path/to/tls', check_sum=False, queue_size=100)
- self.assertEqual(config.is_benchmark_device, True)
- self.assertEqual(config.connect_ip, '192.168.1.1')
- self.assertEqual(config.connect_port, 9090)
- self.assertEqual(config.nfs_path, self.nfs_path)
- self.assertEqual(config.tls_path, '/path/to/tls')
- self.assertFalse(config.check_sum)
- self.assertEqual(config.queue_size, 100)
-
- @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl.move2target_device')
- def test_upload_api_data(self, mock_move2target_device):
- mock_move2target_device.return_value = self.buffer
- self.attls.upload(self.buffer)
- mock_move2target_device.assert_called_once_with(self.buffer, torch.device('cpu'))
-
- @patch('glob.glob')
- def test_download_no_files(self, mock_glob):
- mock_glob.return_value = []
- result = self.attls.download()
- self.assertIsNone(result)
-
- @patch('glob.glob')
- @patch('msprobe.pytorch.common.utils.load_pt')
- def test_download_with_exception(self, mock_load_pt, mock_glob):
- mock_glob.return_value = ['/tmp/start_file.pt']
- mock_load_pt.side_effect = Exception('Load error')
- with patch.object(self.attls.logger, 'warning') as mock_logger:
- result = self.attls.download()
- self.assertIsNone(result)
- mock_logger.assert_called_once()
-
- def test_move2device_exec_tensor(self):
- tensor = torch.randn(2, 2)
- device = torch.device("cpu")
- moved_tensor = move2device_exec(tensor, device)
- self.assertEqual(moved_tensor.device, device)
-
- def test_move2device_exec_list(self):
- tensor_list = [torch.randn(2, 2), torch.randn(2, 2)]
- device = torch.device("cpu")
- moved_list = move2device_exec(tensor_list, device)
- for tensor in moved_list:
- self.assertEqual(tensor.device, device)
-
- def test_move2device_exec_tuple(self):
- tensor_tuple = (torch.randn(2, 2), torch.randn(2, 2))
- device = torch.device("cpu")
- moved_tuple = move2device_exec(tensor_tuple, device)
- for tensor in moved_tuple:
- self.assertEqual(tensor.device, device)
-
- def test_move2device_exec_dict(self):
- tensor_dict = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
- device = torch.device("cpu")
- moved_dict = move2device_exec(tensor_dict, device)
- for tensor in moved_dict.values():
- self.assertEqual(tensor.device, device)
-
- def test_move2device_exec_device(self):
- device = torch.device("cpu")
- moved_device = move2device_exec(torch.device("cpu"), device)
- self.assertEqual(moved_device, device)
-
- def test_move2device_exec_non_tensor(self):
- obj = "This is a string"
- device = torch.device("cpu")
- self.assertEqual(move2device_exec(obj, device), obj)
-
- def test_move2target_device_to_cpu(self):
- tensor_args = (torch.randn(2, 2), torch.randn(3, 3))
- tensor_kwargs = {'key1': torch.randn(2, 2), 'key2': torch.randn(3, 3)}
- tensor_result = torch.randn(2, 2)
- buffer = ApiData('test_api', tensor_args, tensor_kwargs, tensor_result, 1, 1)
- target_device = torch.device('cpu')
- moved_buffer = move2target_device(buffer, target_device)
- self.assertEqual(moved_buffer.result.device, target_device)
- for tensor in moved_buffer.args:
- self.assertEqual(tensor.device, target_device)
- for tensor in moved_buffer.kwargs.values():
- self.assertEqual(tensor.device, target_device)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py
deleted file mode 100644
index d35cfc3387559064298a451fb9d868838bb25aac..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_client.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# coding=utf-8
-import unittest
-from unittest.mock import patch, MagicMock
-from multiprocessing import Queue
-
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import *
-from msprobe.core.common.file_utils import create_directory
-
-
-class TestClient(unittest.TestCase):
-
- def setUp(self) -> None:
- self.host = "localhost"
- self.port = 8000
- self.check_sum = False
- tls_path = "temp_tls_path"
- create_directory(tls_path)
- self.tls_path = os.path.realpath(tls_path)
-
- def tearDown(self) -> None:
- for filename in os.listdir(self.tls_path):
- os.remove(os.path.join(self.tls_path, filename))
- os.rmdir(self.tls_path)
-
- def test_TCPDataItem(self):
- data_item = TCPDataItem(data="example_data", sequence_number=10, rank=1, step=2)
- self.assertEqual(data_item.raw_data, "example_data")
- self.assertEqual(data_item.sequence_number, 10)
- self.assertEqual(data_item.rank, 1)
- self.assertEqual(data_item.step, 2)
- self.assertEqual(data_item.retry_times, 0)
- self.assertEqual(data_item.pending_time, 0)
- self.assertEqual(data_item.busy_time, 0)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py
deleted file mode 100644
index 726714b7993081044a2ca6909db357d3995ad296..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import io
-import queue
-import struct
-import time
-import unittest
-from unittest.mock import MagicMock, patch
-
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import (
- TCPServer,
- ServerProtocol,
- MessageServerFactory
-)
-
-
-class TestTCPServer(unittest.TestCase):
- def setUp(self):
- self.shared_queue = queue.Queue()
- self.tcp_server = TCPServer("6000", self.shared_queue)
- self.tcp_server.tls_path = "/test/path"
- self.tcp_server.factory = MagicMock()
-
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.reactor")
- def test_run_reactor(self, mock_reactor):
- self.tcp_server.run_reactor()
- mock_reactor.run.assert_called_once_with(installSignalHandlers=False)
-
- def test_is_running(self):
- self.tcp_server.is_running()
- self.tcp_server.factory.is_all_connection_closed.assert_called_once_with()
-
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.reactor")
- def test_stop(self, mock_reactor):
- self.tcp_server.reactor_thread = MagicMock()
- self.tcp_server.stop()
- mock_reactor.callFromThread.assert_called_once()
- self.tcp_server.reactor_thread.join.assert_called_once()
-
-
-class TestServerProtocol(unittest.TestCase):
- def setUp(self):
- self.shared_queue = queue.Queue()
- self.server_protocol = ServerProtocol(self.shared_queue)
- self.server_protocol.start_time = time.time()
- self.server_protocol.factory = MagicMock()
- self.server_protocol.factory.transport_dict = {}
- self.server_protocol.factory.transport_list = []
- self.server_protocol.transport = MagicMock()
-
- def test_connectionMade(self):
- self.server_protocol.connectionMade()
- self.assertEqual(self.server_protocol.tell, 0)
- self.assertEqual(self.server_protocol.factory.transport_dict[self.server_protocol.transport], 1)
- self.assertTrue(self.server_protocol.transport in self.server_protocol.factory.transport_list)
-
- def test_connectionLost(self):
- self.server_protocol.factory.transport_dict[self.server_protocol.transport] = 1
- self.server_protocol.connectionLost("test")
- self.assertEqual(len(self.server_protocol.factory.transport_dict), 0)
- self.assertEqual(self.server_protocol.consumer_queue.get(), self.server_protocol.ACK_KILL_PROCESS)
-
- def test_send_ack(self):
- self.server_protocol.sequence_number = 1
- self.server_protocol.rank = 0
- self.server_protocol.step = 0
- self.server_protocol.send_ack(b'test message')
- expected_value = b''.join([
- b'test message',
- b'\x00\x00\x00\x00\x00\x00\x00\x01',
- b'\x00\x00\x00\x00\x00\x00\x00\x00',
- b'\x00\x00\x00\x00\x00\x00\x00\x00',
- ])
- self.server_protocol.transport.write.called_once_with(expected_value)
-
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.zlib.crc32")
- def test_post_process_error(self, mock_zlib_crc32):
- self.shared_queue.maxsize = 1
- self.server_protocol.send_ack = MagicMock()
-
- def mock_send_ack_method1():
- self.server_protocol.consumer_queue.put(1)
-
- def mock_send_ack_method2():
- pass
-
- self.server_protocol.send_ack.side_effect = [mock_send_ack_method1, mock_send_ack_method2]
- self.server_protocol.check_sum = True
- mock_zlib_crc32.return_value = 123
- self.server_protocol.rank = 0
- self.server_protocol.step = 0
- self.server_protocol.post_process()
- mock_zlib_crc32.assert_called()
- self.server_protocol.send_ack.assert_any_call(self.server_protocol.ACK_ERROR)
- self.assertEqual(self.server_protocol.rank, -1)
- self.assertEqual(self.server_protocol.step, -1)
-
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.zlib.crc32")
- def test_post_process_success(self, mock_zlib_crc32):
- mock_zlib_crc32.return_value = 123
- self.shared_queue.maxsize = 1
- self.server_protocol.send_ack = MagicMock()
-
- def mock_send_ack_method1():
- self.server_protocol.consumer_queue.put(1)
-
- def mock_send_ack_method2():
- pass
-
- self.server_protocol.send_ack.side_effect = [mock_send_ack_method1, mock_send_ack_method2]
- self.server_protocol.check_sum = False
- self.server_protocol.obj_body = self.server_protocol.ACK_SUCCESS
- self.server_protocol.post_process()
- self.server_protocol.send_ack.assert_any_call(self.server_protocol.ACK_SUCCESS)
-
- def test_handle_with_stop(self):
- self.server_protocol.send_ack = MagicMock()
- self.server_protocol.handle_with_stop()
- self.server_protocol.send_ack.assert_called_once_with(self.server_protocol.ACK_STOP_CONFIRM)
- self.assertEqual(self.server_protocol.consumer_queue.get(), self.server_protocol.ACK_KILL_PROCESS)
-
- def test_reset_env(self):
- self.server_protocol.obj_length = 10
- self.server_protocol.sequence_number = 1
- self.server_protocol.rank = 2
- self.server_protocol.step = 3
- self.server_protocol.reset_env()
- self.assertEqual(self.server_protocol.obj_length, None)
- self.assertEqual(self.server_protocol.sequence_number, -1)
- self.assertEqual(self.server_protocol.rank, -1)
- self.assertEqual(self.server_protocol.step, -1)
-
- def test_dataReceived(self):
- self.server_protocol.buffer = io.BytesIO()
- self.server_protocol.post_process = MagicMock()
- unpack_mode = '!Q'
- header = struct.pack(unpack_mode, 10)
- header += struct.pack(unpack_mode, 1)
- header += struct.pack(unpack_mode, 2)
- header += struct.pack(unpack_mode, 3)
-
- self.server_protocol.dataReceived(header)
-
- self.assertEqual(self.server_protocol.obj_length, 10)
- self.assertEqual(self.server_protocol.sequence_number, 1)
- self.assertEqual(self.server_protocol.rank, 2)
- self.assertEqual(self.server_protocol.step, 3)
-
-
-class TestMessageServerFactory(unittest.TestCase):
- def setUp(self):
- self.message_server_factory = MessageServerFactory()
-
- def test_is_all_connection_closed(self):
- all_conn_closed = self.message_server_factory.is_all_connection_closed()
- self.assertTrue(all_conn_closed)
-
- self.message_server_factory.transport_dict = {"test1": 1}
- all_conn_closed = self.message_server_factory.is_all_connection_closed()
- self.assertFalse(all_conn_closed)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py
deleted file mode 100644
index 5df5ee879287931512dcca5f7de2daca5bcef284..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from unittest.mock import MagicMock, patch
-
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import run_ut_process, \
- online_precision_compare, online_compare, ConsumerDispatcher
-from msprobe.pytorch.common.log import logger
-
-
-class TestDeviceDispatchFunc(unittest.TestCase):
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.online_compare")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.torch")
- def test_run_ut_process(self, mock_torch, mock_online_compare):
- xpu_id = 1
- mock_consumer_queue = MagicMock()
- mock_consumer_queue.empty.side_effect = [True, False, False]
- mock_api_data = MagicMock()
- mock_api_data.name.split.return_value = ("test", "conv2d", 1)
- mock_consumer_queue.get.side_effect = [mock_api_data, "KILL_"]
-
- run_ut_process(xpu_id, mock_consumer_queue, None, None)
- mock_online_compare.assert_called_with(mock_api_data, mock_torch.device(), None)
-
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.UtDataInfo")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.exec_api")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.generate_cpu_params")
- def test_online_precision_compare(self, mock_gen_cpu_params, mock_exec_api, mock_ut_data_info):
- with patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device"), \
- patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.pd"), \
- patch(
- "msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.online_api_precision_compare"):
- mock_gen_cpu_params.return_value = (MagicMock())
- mock_api_data = MagicMock()
- mock_api_data.name.split.return_value = ("tensor", "conv2d", 1)
- mock_com_config = MagicMock()
- mock_api_precision_csv_file = [MagicMock(), MagicMock()]
- online_precision_compare(mock_api_data, None, mock_com_config, mock_api_precision_csv_file)
- mock_gen_cpu_params.assert_called()
- mock_exec_api.assert_called()
- mock_ut_data_info.assert_called()
-
- @patch.object(logger, "info")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device")
- def test_online_compare_success(self, mock_move2target_device, mock_logger_info):
- api_data = MagicMock()
- api_data.name = "test_api_name"
- common_config = MagicMock()
- common_config.compare.compare_output.return_value = ("test_fwd_success", "test_bwd_success")
- online_compare(api_data, None, common_config)
- mock_move2target_device.assert_called()
- mock_logger_info.assert_called_once_with("running api_full_name test_api_name ut, "
- "is_fwd_success: test_fwd_success, "
- "is_bwd_success: test_bwd_success")
-
- @patch.object(logger, "error")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.move2target_device")
- def test_online_compare_failed(self, mock_move2target_device, mock_logger_error):
- api_data = MagicMock()
- api_data.name.split.return_value = ["tensor", "conv2d", 1]
- common_config = MagicMock()
- online_compare(api_data, None, common_config)
- mock_move2target_device.assert_called()
- mock_logger_error.assert_called()
-
-
-class TestConsumerDispatcher(unittest.TestCase):
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.mp")
- def setUp(self, mock_mq):
- self.mock_mq = mock_mq
- self.consumer_dispatcher = ConsumerDispatcher(None)
-
- @patch.object(logger, "info")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.mp")
- @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.CommonCompareConfig")
- def test_start(self, mock_com_compare_config, mock_mq, mock_log_info):
- self.consumer_dispatcher.start(None, None)
- mock_com_compare_config.assert_called_once_with(None, None, None)
- mock_mq.Process.assert_called()
- mock_log_info.assert_any_call("Successfully start unittest process.")
-
- @patch.object(logger, "info")
- def test_stop(self, mock_log_info):
- mock_queue = MagicMock()
- mock_queue.full.side_effect = [True, False]
- self.consumer_dispatcher.queues = [mock_queue]
-
- mock_process = MagicMock()
- self.consumer_dispatcher.processes = [mock_process]
- self.consumer_dispatcher.stop()
- mock_log_info.assert_any_call("Successfully stop unittest process.")
- mock_process.join.assert_called()
-
- def test_update_consume_queue(self):
- self.consumer_dispatcher._choose_max_empty_site_strategy = MagicMock()
- self.consumer_dispatcher._choose_max_empty_site_strategy.return_value = 0
- mock_queue = MagicMock()
- self.consumer_dispatcher.queues = [mock_queue]
- self.consumer_dispatcher.update_consume_queue("test_data")
- mock_queue.put.assert_called_once_with("test_data")
-
- def test_choose_max_empty_site_strategy(self):
- mock_queue = MagicMock()
- mock_queue.qsize.return_value = 1
- self.consumer_dispatcher.queues = [mock_queue]
- self.consumer_dispatcher.capacity = 5
- self.consumer_dispatcher.reverse_sort = False
- result = self.consumer_dispatcher._choose_max_empty_site_strategy()
- self.assertEqual(result, 0)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py
deleted file mode 100644
index b363c5f0316ad7f66d660b5644b5b25a104be82f..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import unittest
-from unittest.mock import Mock, patch
-
-from OpenSSL import crypto, SSL
-
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import verify_callback, is_certificate_revoked
-
-
-class TestVerifyCallback(unittest.TestCase):
- """
- Test for verify_callback and is_certificate_revoked.
- """
-
- def setUp(self):
- self.conn = Mock(spec=SSL.Connection)
- self.cert = Mock(spec=crypto.X509)
- self.crl = [Mock()]
- self.crl[0].serial_number = 89981275109692867917699502952114227065605526936
-
- @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils.is_certificate_revoked')
- def test_preverify_ok(self, mock_is_certificate_revoked):
- mock_is_certificate_revoked.return_value = False
- self.assertTrue(verify_callback(self.conn, self.cert, 0, 0, 1, self.crl))
-
- @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils.is_certificate_revoked')
- def test_preverify_not_ok(self, mock_is_certificate_revoked):
- self.assertFalse(verify_callback(self.conn, self.cert, 0, 0, 0, None))
-
- mock_is_certificate_revoked.return_value = False
- self.assertEqual(verify_callback(self.conn, self.cert, 0, 0, 1, self.crl), 1)
-
- def test_is_certificate_revoked_true(self):
- self.cert.get_serial_number.return_value = 89981275109692867917699502952114227065605526936
- result = is_certificate_revoked(self.cert, self.crl)
- self.assertTrue(result)
-
- def test_is_certificate_revoked_false(self):
- self.cert.get_serial_number.return_value = 89981275109692867917699502952114227065605526937
- result = is_certificate_revoked(self.cert, self.crl)
- self.assertFalse(result)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
index f086c61c9039d096a66fe437672bb26b5b464295..56af1d2198d2270733cf533fe0ea556598bcf30b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
@@ -35,17 +35,6 @@ class TestDebuggerConfig(unittest.TestCase):
self.assertEqual(debugger.handler_type, "check")
self.assertTrue(debugger.preheat_config["if_preheat"])
- def test_online_run_ut_initialization(self):
- self.task_config.online_run_ut = True
- self.task_config.nfs_path = "./nfs_path"
- self.task_config.tls_path = "./tls_path"
- self.task_config.host = "localhost"
- self.task_config.port = 8080
-
- debugger = DebuggerConfig(self.common_config, self.task_config, Const.TENSOR, None, None)
- self.assertTrue(debugger.online_run_ut)
- self.assertEqual(debugger.nfs_path, "./nfs_path")
- self.assertEqual(debugger.port, 8080)
def test_check_kwargs_with_invalid_task(self):
self.common_config.task = "invalid_task"
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
index bcf6fb501b9b95f324de78622ec3265f50747212..2712281bef7cafe1e4e2ad82def5eb13c7716f9b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
@@ -4,7 +4,7 @@ import unittest
from unittest.mock import patch
from msprobe.core.common.const import Const
-from msprobe.pytorch.pt_config import parse_json_config, parse_task_config, TensorConfig, \
+from msprobe.pytorch.pt_config import parse_json_config, parse_task_config, \
StatisticsConfig, OverflowCheckConfig, FreeBenchmarkCheckConfig, RunUTConfig, GradToolConfig
@@ -82,83 +82,6 @@ class TestPtConfig(unittest.TestCase):
self.assertEqual(result.error_data_path, '/home/dump_path')
-class TestTensorConfig(unittest.TestCase):
-
- def setUp(self):
- self.json_config = {
- "online_run_ut": False,
- "host": "127.0.0.1",
- "port": 8080
- }
- self.config = TensorConfig(self.json_config)
-
- def test_check_file_format_valid(self):
- self.config.file_format = "npy"
- self.config._check_file_format()
-
- self.config.file_format = "bin"
- self.config._check_file_format()
-
- def test_check_file_format_invalid(self):
- self.config.file_format = "invalid_format"
- with self.assertRaises(Exception) as context:
- self.config._check_file_format()
- self.assertIn(str(context.exception), "file_format is invalid")
-
- def test_check_online_run_ut(self):
-
- self.config.online_run_ut = "True"
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), f"online_run_ut: {self.config.online_run_ut} is invalid.")
- self.config.online_run_ut = True
-
- self.config.online_run_ut_recompute = "True"
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), f"online_run_ut_recompute: {self.config.online_run_ut} is invalid.")
- self.config.online_run_ut_recompute = False
-
- self.config.nfs_path = "./nfs_path"
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ")
- self.config.nfs_path = ""
-
- self.config.tls_path = "./tls_path"
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), "[msprobe] 非法文件路径: ")
-
- os.makedirs(self.config.tls_path)
- with open(os.path.join(self.config.tls_path, "client.key"), 'w') as file:
- file.write("1")
- with open(os.path.join(self.config.tls_path, "client.crt"), 'w') as file:
- file.write("1")
- with open(os.path.join(self.config.tls_path, "ca.crt"), 'w') as file:
- file.write("1")
- with open(os.path.join(self.config.tls_path, "crl.pem"), 'w') as file:
- file.write("1")
- self.config._check_online_run_ut()
- shutil.rmtree(self.config.tls_path)
- self.config.tls_path = ""
-
- self.config.host = "invalid_host"
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), f"host: {self.config.host} is invalid.")
- self.config.host = "127.0.0.1"
-
- self.config.port = -1
- with self.assertRaises(Exception) as context:
- self.config._check_online_run_ut()
- self.assertIn(str(context.exception), f"port: {self.config.port} is invalid, port range 0-65535.")
- self.config.port = 6123
-
- # all config right
- self.config._check_online_run_ut()
-
-
class TestStatisticsConfig(unittest.TestCase):
def setUp(self):
@@ -365,60 +288,6 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase):
self.assertIn("The pert_mode when opening fix handler must be one of", str(mock_error.call_args))
-class TestRunUTConfig(unittest.TestCase):
-
- @patch('msprobe.pytorch.hook_module.utils.get_ops', return_value=['relu', 'gelu', 'conv2d'])
- def setUp(self, mock_get_ops):
- self.config = RunUTConfig({
- "white_list": ["relu"],
- "black_list": ["gelu"]
- })
-
- def test_check_filter_list_config_invalid_type(self):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_filter_list_config(Const.WHITE_LIST, "not_a_list")
- self.assertIn("must be a list type", str(context.exception))
-
- def test_check_filter_list_element_config_invalid_type(self):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_filter_list_config("white_list", [1, 1])
- self.assertIn("All elements in ", str(context.exception))
-
- def test_check_filter_list_config_invalid_item(self):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_filter_list_config("white_list", ["api1"])
- self.assertIn("Invalid api in white_list:", str(context.exception))
-
- @patch('os.path.exists', return_value=False)
- def test_check_error_data_path_config_not_exist(self, mock_exists):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_error_data_path_config("./invalid_path")
- self.assertIn("does not exist", str(context.exception))
-
- @patch('os.path.exists', return_value=False)
- def test_check_nfs_path_config_not_exist(self, mock_exists):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_nfs_path_config("./invalid_nfs")
- self.assertIn("[msprobe] 非法文件路径:", str(context.exception))
-
- @patch('os.path.exists', return_value=False)
- def test_check_tls_path_config_not_exist(self, mock_exists):
- with self.assertRaises(Exception) as context:
- RunUTConfig.check_tls_path_config("./invalid_tls")
- self.assertIn("[msprobe] 非法文件路径:", str(context.exception))
-
- def test_check_run_ut_config(self):
- with patch.object(RunUTConfig, 'check_filter_list_config') as mock_filter, \
- patch.object(RunUTConfig, 'check_error_data_path_config') as mock_error, \
- patch.object(RunUTConfig, 'check_nfs_path_config') as mock_nfs, \
- patch.object(RunUTConfig, 'check_tls_path_config') as mock_tls:
- self.config.check_run_ut_config()
- mock_filter.assert_called()
- mock_error.assert_called()
- mock_nfs.assert_called()
- mock_tls.assert_called()
-
-
class TestGradToolConfig(unittest.TestCase):
def setUp(self):
self.level_adp = {"L1": None, "L2": None}
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py
index f9d5744a957eaf8b4fefbe737cfeb5866c961f6e..d1419ab12caf47bd10c7aabb7991ffcc694b8f5d 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py
@@ -18,7 +18,6 @@ from unittest.mock import MagicMock, patch
from msprobe.pytorch.pytorch_service import PytorchService
from msprobe.core.common.utils import Const
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.attl_manager import ATTLManager
from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -29,7 +28,6 @@ class TestPytorchService(unittest.TestCase):
self.config.rank = []
self.config.level = Const.LEVEL_MIX
self.config.task = Const.STATISTICS
- self.config.online_run_ut_recompute = False
with patch('msprobe.core.service.build_data_collector'):
self.service = PytorchService(self.config)
@@ -37,8 +35,6 @@ class TestPytorchService(unittest.TestCase):
self.service.logger = MagicMock()
self.service.data_collector = MagicMock()
self.service.module_processor = MagicMock()
- self.service.attl_manager = MagicMock(spec=ATTLManager)
- self.service.attl_manager.attl = MagicMock()
self.service.api_register = MagicMock()
def test_framework_type(self):
@@ -56,12 +52,10 @@ class TestPytorchService(unittest.TestCase):
self.assertIsNotNone(service.logger)
self.assertIsNotNone(service.api_register)
self.assertIsNotNone(service.module_processor)
- self.assertIsNotNone(service.attl_manager)
self.assertIsNotNone(service.hook_manager)
def test_register_hook(self):
self.service._register_hook()
- self.service.attl_manager.attl_init.assert_called_once()
@patch('msprobe.pytorch.pytorch_service.register_optimizer_hook')
def test_register_hook_mix_level(self, mock_register_opt):
@@ -93,24 +87,7 @@ class TestPytorchService(unittest.TestCase):
self.assertTrue(self.service.module_processor.enable_module_dump)
- @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True)
- @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch')
- def test_run_ut_dispatch(self, mock_run_ut):
- status = True
- self.service._run_ut_dispatch(status)
- mock_run_ut.assert_called_once_with(
- self.service.attl_manager.attl,
- status,
- self.config.online_run_ut_recompute
- )
-
- @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=False)
- @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch')
- def test_run_ut_dispatch_torch_version_below_2(self, mock_run_ut):
- status = True
- self.service._run_ut_dispatch(status)
- mock_run_ut.assert_not_called()
-
+
@patch.object(HOOKModule, 'reset_module_stats')
@patch.object(ModuleProcesser, 'reset_module_stats')
def test_reset_status(self, mock_reset_module_processor, mock_reset_hook_module):
@@ -119,60 +96,8 @@ class TestPytorchService(unittest.TestCase):
mock_reset_module_processor.assert_called_once()
self.service.data_collector.reset_status.assert_called_once()
- @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True)
- @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch')
- def test_start_with_online_run_ut(self, mock_run_ut):
- self.service.config.online_run_ut = True
- self.service.data_collector.data_processor.is_terminated = False
- model_mock = MagicMock()
-
- self.service.start(model=model_mock)
-
- mock_run_ut.assert_called_once_with(
- self.service.attl_manager.attl,
- True,
- self.config.online_run_ut_recompute
- )
-
- @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', return_value=True)
- @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch')
- def test_stop_with_online_run_ut(self, mock_run_ut, mock_version):
- self.service.config.online_run_ut = True
- self.service.current_iter = 1
- self.service.current_rank = 0
- self.service.attl_manager.attl = MagicMock()
- self.service.stop()
-
- mock_run_ut.assert_called_once_with(
- self.service.attl_manager.attl,
- False,
- self.config.online_run_ut_recompute
- )
-
+
def test_register_module_hook(self):
self.service.model = MagicMock()
self.service._register_module_hook()
self.service.module_processor.register_module_hook.assert_called_once()
-
- @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True)
- @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch')
- def test_run_ut_dispatch_with_recompute(self, mock_run_ut):
- self.service.attl_manager.attl = None
- self.service.config.online_run_ut_recompute = True
- status = True
- self.service._run_ut_dispatch(status)
- mock_run_ut.assert_called_once_with(
- self.service.attl_manager.attl,
- status,
- True
- )
-
- def test_attl_manager_interaction(self):
- self.service.config.online_run_ut = True
- self.service.data_collector.data_processor.is_terminated = False
- self.service.start(model=MagicMock())
- self.service.attl_manager.attl_init.assert_called_once()
-
- self.service.data_collector.data_processor.is_terminated = True
- self.service.start()
- self.service.attl_manager.attl_stop.assert_called_once()