diff --git a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/compile_model.py b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/compile_model.py index ae09673c85677e179f225dc1a96fd3802be819e6..459ce1aa282ae25801eea0f61791e857aead814d 100644 --- a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/compile_model.py +++ b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/compile_model.py @@ -16,6 +16,8 @@ import sys from transformers import AutoTokenizer, AutoModel import torch +import torch_aie +from torch_aie import _enums import numpy as np import argparse @@ -28,14 +30,18 @@ def main(): required=False, help='npu device') parser.add_argument('--need_trace', default="true", required=False, help='If you have traced the model before then set false') + parser.add_argument('--need_compile', default="true", + required=False, help='If you have compiled the model before then set false') args = parser.parse_args() device = args.device model_path = args.pretrained_model need_trace = args.need_trace + need_compile = args.need_compile model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torchscript=True).float() model.eval() + torch_aie.set_device(device) # stage1: model trace if need_trace == "true": print("===================== start to trace model ==========================") @@ -54,6 +60,46 @@ def main(): torch.jit.save(traced_model, traced_model_path) print("===================== model trace success ==========================") + # stage2: model compile + if need_compile == "true": + ## load origin traced model + traced_model_path = "./chatglm2_6b_batch_1_traced.pt" + try: + traced_model = torch.jit.load(traced_model_path) + except Exception as e: + print("load model failed, please trace first.") + + ## set compile config + inputs = [] + max_seqlen = 10000 + input0_min_shape = (1, 1) + input0_max_shape = (1, max_seqlen) + input1_min_shape = (1, 1) + input1_max_shape = (1, max_seqlen) + input2_min_shape = (1, 1) + input2_max_shape = (1, max_seqlen) + input3_min_shape = (1, 2, 0, 1, 2, 128) + input3_max_shape = (28, 2, max_seqlen, 1, 2, 128) + + inputs.append(torch_aie.Input(min_shape = input0_min_shape, max_shape = input0_max_shape, dtype = torch.int64)) + inputs.append(torch_aie.Input(min_shape = input1_min_shape, max_shape = input1_max_shape, dtype = torch.int64)) + inputs.append(torch_aie.Input(min_shape = input2_min_shape, max_shape = input2_max_shape, dtype = torch.int64)) + inputs.append(torch_aie.Input(min_shape = input3_min_shape, max_shape = input3_max_shape, dtype = torch.float32)) + + ## compile + print("===================== start to compile model ==========================") + compiled_module = torch_aie.compile( + traced_model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP32, + allow_tensor_replace_int=True, + soc_version="Ascend910B4" # 可以为Ascend910B3或者Asend910B4,具体根据使用的环境决定。 + ) + print("===================== model compile success ==========================") + ## save compiled result + aie_model_path = "./chatglm2_6b_batch_1_compiled.ts" + compiled_module.save(aie_model_path) + print("===================== save compiled model success ======================") if __name__ == '__main__': diff --git a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/example.py b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/example.py index ac1b0be2edff3302add1e28946bdcc547ff3c694..d25f2e986362ad5fe686efe42ca0b1c14b2e19cb 100644 --- a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/example.py +++ b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/example.py @@ -34,7 +34,7 @@ def signal_handler(signal, frame): def parse_arg(): parser = argparse.ArgumentParser() - parser.add_argument("--device", default="npu", help="cpu/npu") + parser.add_argument("--device", default=0, type=int, help="npu device") args = parser.parse_args() return args @@ -44,12 +44,10 @@ def main(): args = parse_arg() device = args.device print("device:", device) - aie_model = None - if device == "npu": - torch_aie.set_device(0) - aie_model_path = "./chatglm2_6b_batch_1_compiled.ts" - aie_model = torch.jit.load(aie_model_path) - aie_model.eval() + torch_aie.set_device(device) + aie_model_path = "./chatglm2_6b_batch_1_compiled.ts" + aie_model = torch.jit.load(aie_model_path) + aie_model.eval() print("欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: diff --git a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/readme.md b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/readme.md index 7a3e8117c859b0cba73c6818916109e7a6064785..ffcd112788da0055605342cf879703ecfc0b2001 100644 --- a/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/readme.md +++ b/AscendIE/TorchAIE/built-in/foundation/ChatGLM2-6B/readme.md @@ -39,9 +39,10 @@ ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本, | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ | | 固件与驱动 | 23.0.0| | CANN | 7.0.0 B050 | - | - | Python | 3.9.0 | - | + | Python | 3.10.0 | - | | PyTorch | 2.1.0 | - | | Torch_AIE | 6.3.rc2 | + | 芯片类型 | Ascend910B3/ Ascend910B4 | # 快速上手 @@ -62,11 +63,8 @@ ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本, ``` pip3 install -r requirement.txt ``` - 需要安装pt插件的python wheel(可根据代码仓中的readme.md操作) 和统一接口的run包。 - 参考: - https://gitee.com/ascend/ascend-inference-ptplugin.git - 这个时候我们可以通过命令`pip show torch`找到torch的目录, 比如'/usr/local/python3/lib/python3.9/site-packages/torch', 这个路径我们定义为${TORCH_ROOT_PATH}, 后续C++编译中需要用到。 + 这个时候我们可以通过命令`pip show torch`找到torch的目录, 比如'/usr/local/python3/lib/python3.10/site-packages/torch', 这个路径我们定义为${TORCH_ROOT_PATH}, 后续C++编译中需要用到。 #### 安装推理引擎统一接口 @@ -84,7 +82,7 @@ ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本, tar -zxvf Ascend-cann-torch-aie-${version}-linux_aarch64.tar.gz pip3 install torch-aie-${version}-linux_aarch64.whl ``` - 这个时候我们可以通过`pip show torch_aie`找到torch_aie的目录, 比如'/usr/local/python3/lib/python3.9/site-packages/torch_aie', 这个路径我们定义为${TORCH_AIE_PATH}, 后续C++编译中需要用到。 + 这个时候我们可以通过`pip show torch_aie`找到torch_aie的目录, 比如'/usr/local/python3/lib/python3.10/site-packages/torch_aie', 这个路径我们定义为${TORCH_AIE_PATH}, 后续C++编译中需要用到。 @@ -115,21 +113,33 @@ ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本, 注意:如果设置的device_id不为0, 那么需要做一下操作。 ``` 1 `sed -i 's/npu:0/npu:${device_id}/' model/modeling_chatglm.py ` 将device_id换成自己定义的id. -2 将 run.sh里头的 `./sample 0` 替换为 `./sample ${device_id}` -3. 将 example.py中49行set_device(0)改为具体的device_id. +2 跑 compile_model中加入 --device {device_id}` +3. 跑example.py加入 --device ${device_id}. ``` -1. trace模型与模型编译。 +### trace模型与模型编译。 + + 使用torch aie将模型源码,trace为pt文件,再通过pt插件转换为ts模型。 + 我们可选直接用python进行一体化配置,或者结合python和C++进行编译。 + #### python一体化流程 - 使用torch aie将模型源码,trace为pt文件。 ``` - python3.9 compile_model.py --device 0 + python3.10 compile_model.py --device 0 ``` compile_model的参数和默认值如下 ``` --device 0 \ # 环境使用的device_id --pretrained_model ./model/ # 源码和权重文件落盘位置 --need_trace true # 是否需要trace + --need_compile true # 是否需要compile + ``` + + #### 结合C++的方式进行compile + 1. trace模型 ``` + python3.10 compile_model.py --need_compile=false + ``` + 2. 用C++compile 模型。 + 模型编译`compile`文件夹内容如下: ```shell ├── compile @@ -149,13 +159,16 @@ ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本, ``` - -2. 开始对话验证。 + ### 对话验证。 ``` python3 example.py ``` + example.py的参数和默认值如下 + ``` + --device 0 \ # 环境使用的device_id + ``` -3. 最后对话的效果如下 + 最后对话的效果如下 ``` 欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序