diff --git a/tutorials/source_zh_cn/index.rst b/tutorials/source_zh_cn/index.rst index c7668dd8271cbfe9c2596ac3de46133407d528bd..0889bc1948576c8e2e74b0a8a9b3f08ec3852988 100644 --- a/tutorials/source_zh_cn/index.rst +++ b/tutorials/source_zh_cn/index.rst @@ -92,10 +92,6 @@ MindSpore教程 model_infer/introduction model_infer/ms_infer/ms_infer_model_infer - model_infer/ms_infer/ms_infer_network_develop - model_infer/ms_infer/ms_infer_parallel_infer - model_infer/ms_infer/ms_infer_quantization - model_infer/ms_infer/ms_infer_model_serving_infer model_infer/lite_infer/overview .. toctree:: diff --git a/tutorials/source_zh_cn/model_infer/introduction.md b/tutorials/source_zh_cn/model_infer/introduction.md index 8e23672fded9cefeea5a9c29728d20274e8abdf2..192448362c9d76f14e17055c9ba21656c9ee0427 100644 --- a/tutorials/source_zh_cn/model_infer/introduction.md +++ b/tutorials/source_zh_cn/model_infer/introduction.md @@ -1,4 +1,4 @@ -# MindSpore推理 +# MindSpore推理概述 [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/model_infer/introduction.md) diff --git a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.md b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst similarity index 52% rename from tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.md rename to tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst index 8e1658e42c8b50c248fbc23e39f0ee6ff0c90f20..802b1b8e0bfefb6ceaabf74123a3509b3a134084 100644 --- a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.md +++ b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst @@ -1,8 +1,21 @@ -# MindSpore大语言模型带框架推理 +MindSpore大语言模型带框架推理 +============================= -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.md) +.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg + :target: https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/parallel/optimize_technique.rst + :alt: 查看源文件 -## 特性背景 +.. toctree:: + :maxdepth: 1 + :hidden: + + ms_infer_network_develop + ms_infer_parallel_infer + ms_infer_quantization + ms_infer_model_serving_infer + +特性背景 +-------- 2022年末,OpenAI发布了ChatGPT大语言模型,为人工智能带来了一个新的研究方向,即基于Transformers结构的大语言模型,其展现了超过人们预期的AI能力,在多项测试中取得较好成绩,快速成为人工智能的研究焦点。 @@ -12,7 +25,8 @@ - 针对大语言模型推理成本高的问题,MindSpore框架提供了大语言模型推理能力,结合当前主流大语言模型的特点,深度优化大语言模型部署和推理,实现模型推理成本最优。 -## 模型原理 +模型原理 +-------- 在了解MindSpore大语言模型推理的能力之前,让我们先了解一下当前主流大语言模型是如何实现让人惊叹的智能水平的,下面我们将以当前最常见的文本生成类大语言模型为例子,简单介绍大语言模型的推理原理,了解AI模型是如何通过计算,完成和人对话、总结文章中心思想等复杂任务的。 @@ -25,46 +39,40 @@ 在实际文本处理场景中,语言是复杂多变的,因此很难直接找到两个句子的直接相关性,大语言模型技术通常会采用单词化的方法,即将“中国的面积有”分解为多个常见的单词组合,例如“中国”、“的”、“面积”、“有”,这种做法不仅可以更好地应对文本差异带来的影响,如"中国的面积是"和"中国的面积有"两个短语相似度可能为0,而["中国","的","面积","是"]和["中国","的","面积","有"]两个组合的相似度就可以认为有75%,可以有效帮助大语言模型识别出这类文本差异,这种技术我们通常称为tokenize,即把一段文本分解为一组token(通常是单词和标点符号之类的元素)的组合表示。大语言模型完成生成一句话的过程就是根据当前的token组合信息,每一轮推理出下一个token,并和之前的token组合到一起,形成新一轮的输入,反复迭代每次生成一个单词,逐步完成整段文本的生成。下表简单描述一下大语言模型推理的例子: 用户输入:中国的首都 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
推理迭代推理输入输入向量推理结果
1中国的首都[中国, 的, 首都]北京
2中国的首都北京[中国, 的, 首都, 北京]
3中国的首都北京真[中国, 的, 首都, 北京, 真]美丽
4中国的首都北京真美丽[中国, 的, 首都, 北京, 真, 美丽]END
+ +.. list-table:: 推理示例 + :header-rows: 1 + + * - 推理迭代 + - 推理输入 + - 输入向量 + - 推理结果 + * - 1 + - 中国的首都 + - [中国, 的, 首都] + - 北京 + * - 2 + - 中国的首都北京 + - [中国, 的, 首都, 北京] + - 真 + * - 3 + - 中国的首都北京真 + - [中国, 的, 首都, 北京, 真] + - 美丽 + * - 4 + - 中国的首都北京真美丽 + - [中国, 的, 首都, 北京, 真, 美丽] + - END 可以看到,每一轮迭代,实际上大语言模型会根据当前语境推理出下一个token,与前面的语句拼装成下一轮迭代的输入,通过多轮迭代,当遇到生成的token是END这个特殊token时,模型认为推理结束,将结果返回给用户。 -## 关键步骤 +关键步骤 +-------- MindSpore大语言模型推理为用户提供了“开箱即用”的大语言模型部署和推理能力,用户能够利用MindSpore提供的大语言模型相关API,快速部署自己的大语言模型,并根据模型特点进行相应优化,实现最优性价比,为实际生产生活带来大语言模型能力。下图是利用MindSpore大语言模型推理特性进行模型推理的关键步骤图: -![llm-infer-flow](./images/llm_infer_flow.png) +.. figure:: ./images/llm_infer_flow.png + :alt: llm-infer-flow 1. **权重准备**:权重数据是大语言模型的智能核心,因此部署模型的第一步就是获取和准备好对应模型的权重文件。 2. **模型加载**:模型推理时,根据使用的不同优化技术,模型结构会有一定的差异,因此需要根据模型网络结构将模型的主干网络构建出来,方便后续进行推理。 @@ -73,122 +81,129 @@ MindSpore大语言模型推理为用户提供了“开箱即用”的大语言 5. **模型推理**:通过输入的数据进行模型推理,通常会返回语句中下一个token的概率分布。 6. **推理后处理**:根据模型推理的结果,计算出下一个token,将token转换成文本返回给用户,同时如果推理没有结束,将token拼装成下一轮推理的输入继续推理。 -## 主要特性 +主要特性 +-------- MindSpore大语言模型为了能够实现最优的性价比,针对大语言模型网络的特性,进行了多项深度优化,其中主要包含以下特性: - **全量/增量推理**:大语言模型的核心网络结构是以transfomer为主的自注意力机制,每一轮迭代都要计算所有token的注意力分数,而实际上相同的token序列计算注意力分数时key和value结果是相同的,即["中国","的","面积","是"]的key和value可以理解为是由["中国","的","面积"]和["是"]拼接而成的,因此可以通过将前面已经计算的序列的key和value值缓存起来,从而减少下一轮迭代推理过程的计算量,这种技术通常被称为KVCache优化。结合大语言模型推理的全过程可以发现,在N和N+1轮的两次连续迭代中,其中N+1轮可以完全复用N轮的key和value值,因为前N个序列是一致的,真正需要计算key和value的只有N+1轮的第一个token,这样我们可以将模型推理分为以下两个阶段: - - **全量推理**:用户输入的第一轮迭代,此时用户给出的长度为N的语句,N的长度和内容都无法预测,需要计算全部key和value的值,成为全量推理。 + - **全量推理**:用户输入的第一轮迭代,此时用户给出的长度为N的语句,N的长度和内容都无法预测,需要计算全部key和value的值,成为全量推理。 - - **增量推理**:完成第一轮迭代计算后,前一轮迭代语句的key和value值已经缓存在KVCache中,此时只需要额外计算最近一个token对应的key和value值,并与缓存的结果拼接起来计算注意力分数,成为增量推理。 + - **增量推理**:完成第一轮迭代计算后,前一轮迭代语句的key和value值已经缓存在KVCache中,此时只需要额外计算最近一个token对应的key和value值,并与缓存的结果拼接起来计算注意力分数,成为增量推理。 - **Attention优化**:大语言模型网络结构最主要的计算是对于Attention的计算,由于当前主流模型的Attention的size比较大(通常4K或以上),模型推理的整个过程性能强依赖于Attention计算的性能,因此当前有很多研究在关注如何优化Attention计算性能,其中比较主流的包括Flash Attention和Page Attention技术。 - - **Flash Attention**:Attention计算中会存在两个大矩阵相乘(4K大小),实际计算会将大矩阵分解为多个芯片能够计算的小矩阵单元进行计算,由于芯片的最小级的缓存大小限制,需要不断地将待计算数据在缓存和主存间搬入搬出,导致计算资源实际无法充分利用,因此当前主流芯片下,Attention计算实际上是带宽bound。Flash Attention技术将原本Attention进行分块,使得每一块计算都能够在芯片上独立计算完成,避免了在计算Key和Value时多次数据的搬入和搬出,从而提升Attention计算性能,具体可以参考[Flash Attention](https://arxiv.org/abs/2205.14135)。 + - **Flash Attention**:Attention计算中会存在两个大矩阵相乘(4K大小),实际计算会将大矩阵分解为多个芯片能够计算的小矩阵单元进行计算,由于芯片的最小级的缓存大小限制,需要不断地将待计算数据在缓存和主存间搬入搬出,导致计算资源实际无法充分利用,因此当前主流芯片下,Attention计算实际上是带宽bound。Flash Attention技术将原本Attention进行分块,使得每一块计算都能够在芯片上独立计算完成,避免了在计算Key和Value时多次数据的搬入和搬出,从而提升Attention计算性能,具体可以参考 `Flash Attention `_。 - - **Page Attention显存优化**:标准的Flash Attention每次会读取和保存整个输入的Key和Value数据,这种方式虽然比较简单,但是会造成较多的资源浪费,如“中国的首都”和“中国的国旗”,都有共同的“中国的”作为公共前缀,其Attention对应的Key和Value值实际上是一样的,标准Flash Attention就需要存两份Key和Value,导致显存浪费。Page Attention基于Linux操作系统页表原理对KVCache进行优化,按照特定大小的块来存储Key和Value的数据,将上面例子中的Key和Value存储为“中国”、“的”、“首都”、“国旗”一共四份Key和Value数据,相比原来的六份数据,有效地节省了显存资源。在服务化的场景下,更多空闲显存可以让模型推理的batch更大,从而获得更高的吞吐量,具体可以参考[Page Attention](https://arxiv.org/pdf/2309.06180)。 + - **Page Attention显存优化**:标准的Flash Attention每次会读取和保存整个输入的Key和Value数据,这种方式虽然比较简单,但是会造成较多的资源浪费,如“中国的首都”和“中国的国旗”,都有共同的“中国的”作为公共前缀,其Attention对应的Key和Value值实际上是一样的,标准Flash Attention就需要存两份Key和Value,导致显存浪费。Page Attention基于Linux操作系统页表原理对KVCache进行优化,按照特定大小的块来存储Key和Value的数据,将上面例子中的Key和Value存储为“中国”、“的”、“首都”、“国旗”一共四份Key和Value数据,相比原来的六份数据,有效地节省了显存资源。在服务化的场景下,更多空闲显存可以让模型推理的batch更大,从而获得更高的吞吐量,具体可以参考 `Page Attention `_。 - **模型量化**:MindSpore大语言模型推理支持通过量化技术减小模型体积,提供了A16W8、A16W4、A8W8量化以及KVCache量化等技术,减少模型资源占用,提升推理吞吐量。 -## 推理教程 +推理教程 +-------- 本章节将会结合当前主流的Qwen2开源大语言模型,演示如何通过MindSpore大语言模型推理提供的能力,逐步构建一个可以端到端进行文本生成的例子。 -> 由于Qwen2模型也有多个版本和配置,本文主要基于Qwen2-7B-Instrcut模型进行说明。 +.. note:: + + 由于Qwen2模型也有多个版本和配置,本文主要基于Qwen2-7B-Instrcut模型进行说明。 -### 环境准备 +环境准备 +~~~~~~~~ MindSpore大语言模型带框架推理主要依赖MindSpore开源软件,用户在使用前,需要先安装MindSpore的Python包,建议使用conda虚拟环境运行。可以执行如下命令简单安装: -```shell -export PYTHON_ENV_NAME=mindspore-infer-py311 -conda create -n ${PYTHON_ENV_NAME} python=3.11 -conda activate ${PYTHON_ENV_NAME} -pip install mindspore -``` +.. code:: shell -同时,用户也可以参考官方安装文档来安装自己环境适配的Python包,具体见[MindSpore安装](https://www.mindspore.cn/install)。 + export PYTHON_ENV_NAME=mindspore-infer-py311 + conda create -n ${PYTHON_ENV_NAME} python=3.11 + conda activate ${PYTHON_ENV_NAME} + pip install mindspore + +同时,用户也可以参考官方安装文档来安装自己环境适配的Python包,具体见 `MindSpore安装 `_。 由于MindSpore推理主要支持Ascend芯片环境上运行,还需要安装相应的Ascend开发环境,具体可以参考: -```shell -pip install ${ASCEND_HOME}/lib64/te-*.whl -pip install ${ASCEND_HOME}/lib64/hccl-*.whl -pip install sympy -``` +.. code:: shell + + pip install ${ASCEND_HOME}/lib64/te-*.whl + pip install ${ASCEND_HOME}/lib64/hccl-*.whl + pip install sympy 如果用户要复用当前主流的LLM模型的tokenizer能力,可以安装Transformers软件包: -```shell -pip install transformers -``` +.. code:: shell -如果用户需要使用模型量化能力提升模型推理性能,还需要安装mindspore_gs包,具体可以参考[MindSpore GoldenStick安装](https://www.mindspore.cn/golden_stick/docs/zh-CN/master/install.html)。 + pip install transformers -### 权重准备 +如果用户需要使用模型量化能力提升模型推理性能,还需要安装mindspore_gs包,具体可以参考 `MindSpore GoldenStick安装 `_。 + +权重准备 +~~~~~~~~ 权重准备主要是获取大语言模型的权重文件。同时,通常每一个大语言模型都有自己对应的token列表,表示该模型支持的单词全集,因此,除了模型的权重外,还需要获取其对应的tokenizer映射。MindSpore当前已经支持直接加载safetensor的权重文件,用户可以直接下载Hugging Face官网上的模型权重文件。 对于Qwen2大语言模型,建议用户直接使用Hugging Face官方网站提供的预训练权重文件与tokenizer映射,用户可以简单地使用下面的命令进行权重下载: -```shell -git lfs install -git clone https://huggingface.co/Qwen/Qwen2-7B-Instruct -``` +.. code:: shell + + git lfs install + git clone https://huggingface.co/Qwen/Qwen2-7B-Instruct 下载完成后,相关目录下应该显示如下文件树结构: -```shell -ls -|- config.json -|- LICENSE -|- merges.txt -|- model-00001-of-00004.safetensors -|- model-00002-of-00004.safetensors -|- model-00003-of-00004.safetensors -|- model-00004-of-00004.safetensors -|- model.safetensors.index.json -|- README.md -|- tokenizer_config.json -|- tokenizer.json -|- vocab.json -``` - -### 模型构建 +.. code:: shell + + ls + |- config.json + |- LICENSE + |- merges.txt + |- model-00001-of-00004.safetensors + |- model-00002-of-00004.safetensors + |- model-00003-of-00004.safetensors + |- model-00004-of-00004.safetensors + |- model.safetensors.index.json + |- README.md + |- tokenizer_config.json + |- tokenizer.json + |- vocab.json + +模型构建 +~~~~~~~~ 首先用户需要构建模型并加载权重,执行以下代码: -```python -import os -import mindspore as ms -from qwen2 import Qwen2Config, Qwen2ForCausalLM, CacheManager -from mindspore import Tensor, mint +.. code:: python -# set mindspore context and envs -os.environ["MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST"] = "PagedAttention" + import os + import mindspore as ms + from qwen2 import Qwen2Config, Qwen2ForCausalLM, CacheManager + from mindspore import Tensor, mint -ms.set_context(infer_boost="on") -ms.set_context(mode=ms.context.PYNATIVE_MODE) + # set mindspore context and envs + os.environ["MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST"] = "PagedAttention" -model_path = "/path/to/model" -input_str = ["I love Beijing, because", "Hello, Qwen2"] -batch_size = len(input_str) -max_new_tokens = 64 -block_size = 128 -max_seq_lens = block_size * 10 -block_num = (max_seq_lens * batch_size) // block_size + ms.set_context(infer_boost="on") + ms.set_context(mode=ms.context.PYNATIVE_MODE) -config = Qwen2Config.from_json(model_path + "/config.json") + model_path = "/path/to/model" + input_str = ["I love Beijing, because", "Hello, Qwen2"] + batch_size = len(input_str) + max_new_tokens = 64 + block_size = 128 + max_seq_lens = block_size * 10 + block_num = (max_seq_lens * batch_size) // block_size -model = Qwen2ForCausalLM(config) -# load weight -model.load_weight(model_path) + config = Qwen2Config.from_json(model_path + "/config.json") -cache_manager = CacheManager(config, block_num, block_size, batch_size) -``` + model = Qwen2ForCausalLM(config) + # load weight + model.load_weight(model_path) -其中,qwen2为模型的网络脚本(qwen2.py),需要和当前脚本在同一个目录下,可以参考[从零构建大语言模型推理网络](./ms_infer_network_develop.md)。用户也可以使用其他的网络脚本,但是需要修改相应的模型接口。 + cache_manager = CacheManager(config, block_num, block_size, batch_size) + +其中,qwen2为模型的网络脚本(qwen2.py),需要和当前脚本在同一个目录下,可以参考 `从零构建大语言模型推理网络 <./ms_infer_network_develop.md>`_。用户也可以使用其他的网络脚本,但是需要修改相应的模型接口。 脚本中第一步是设置mindspore相关环境变量,包括: @@ -212,158 +227,160 @@ cache_manager = CacheManager(config, block_num, block_size, batch_size) 根据以上参数对模型进行初始化,获得model和cache_manager对象。 -### 模型推理 +模型推理 +~~~~~~~~ 模型构建好之后,用户就可以使用模型对象来进行文本生成,实现如自助客服、智能问答、聊天机器人等实际应用。但是应用的输入通常是一句语言的文本,无法直接作为模型的输入进行计算。因此,我们需要增加前处理和后处理的逻辑,将文本语言转换成模型能够识别的token数据,在完成推理计算后,再将token数据转换成文本语言。我们以一句简单的问答文本生成为例子,简单描述这个过程: - **前处理**:利用tokenizer的数据,将一句话分解为多个token id表示的list。此处,我们使用Transformers开源社区的的tokenizer。 - ```python - from transformers import AutoTokenizer + .. code:: python + + from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - input_str = ["I love Beijing, because", "Hello, Qwen2"] + input_str = ["I love Beijing, because", "Hello, Qwen2"] - input_ids = tokenizer(input_str)["input_ids"] + input_ids = tokenizer(input_str)["input_ids"] - print(input_ids) - ``` + print(input_ids) - 执行此Python代码,会打印如下输出: + 执行此Python代码,会打印如下输出: - ```shell - [【40, 2948, 26549, 11, 1576】, 【9707, 11, 1207, 16948, 17】] - ``` + .. code:: shell - 其中,【40, 2948, 26549, 11, 1576】对应"I love Beijing, because"的单词序列,40表示I对应的token,2948表示love对应的token,26549表示Beijing对应的token,11表示逗号加空格对应的token,1576表示because对应的token,这个格式可以直接传给模型进行推理。同理【9707, 11, 1207, 16948, 17】对应‘Hello, Qwen2’的输入序列。此处采用一次传入2个请求的batch计算进行演示。 + [[40, 2948, 26549, 11, 1576], [9707, 11, 1207, 16948, 17]] + + 其中,[40, 2948, 26549, 11, 1576]对应"I love Beijing, because"的单词序列,40表示I对应的token,2948表示love对应的token,26549表示Beijing对应的token,11表示逗号加空格对应的token,1576表示because对应的token,这个格式可以直接传给模型进行推理。同理[9707, 11, 1207, 16948, 17]对应"Hello, Qwen2"的输入序列。此处采用一次传入2个请求的batch计算进行演示。 - **整网计算**:传入当前输入token的数据和配置,让模型对象通过多轮计算迭代推理出每轮的token结果。为了代码更加简洁,可以将迭代推理封装到如下generate函数中: - ```python - from typing import List - from mindspore import ops, mint, Tensor, dtype - from qwen2 import Qwen2Config, Qwen2ModelInput, Qwen2ForCausalLM, CacheManager, sample + .. code:: python + + from typing import List + from mindspore import ops, mint, Tensor, dtype + from qwen2 import Qwen2Config, Qwen2ModelInput, Qwen2ForCausalLM, CacheManager, sample - def generate(model: Qwen2ForCausalLM, cache_manager: CacheManager, input_ids: List, max_new_tokens: int, max_seq_lens: int, eos_token_id: int): - batch_size = len(input_ids) - assert max_seq_lens >= max(map(len, input_ids)) + def generate(model: Qwen2ForCausalLM, cache_manager: CacheManager, input_ids: List, max_new_tokens: int, max_seq_lens: int, eos_token_id: int): + batch_size = len(input_ids) + assert max_seq_lens >= max(map(len, input_ids)) - cur = min(map(len, input_ids)) - is_prefill = True - it = 0 + cur = min(map(len, input_ids)) + is_prefill = True + it = 0 - decode_q_seq_lens = Tensor([1 for _ in range(batch_size)], dtype=dtype.int32) - decode_mask = ops.zeros((1, 1), dtype=config.param_dtype) - attn_mask = None - q_seq_lens = None + decode_q_seq_lens = Tensor([1 for _ in range(batch_size)], dtype=dtype.int32) + decode_mask = ops.zeros((1, 1), dtype=config.param_dtype) + attn_mask = None + q_seq_lens = None - while cur <= max_seq_lens and it < max_new_tokens: - batch_valid_length = Tensor([cur for _ in range(batch_size)], dtype=dtype.int32) - if is_prefill: - inp = Tensor([input_ids[i][:cur] for i in range(batch_size)], dtype=dtype.int32) - pos = mint.arange(cur).astype(dtype.int32) - block_tables, slot_mapping = cache_manager.step(0, cur) - attn_mask = ops.logical_not(ops.sequence_mask(pos + 1, cur)).astype(config.param_dtype) - q_seq_lens = None - else: - inp = Tensor([[input_ids[i][cur - 1]] for i in range(batch_size)], dtype=dtype.int32) - pos = Tensor([[cur - 1] for _ in range(batch_size)], dtype=dtype.int32).view(-1) - block_tables, slot_mapping = cache_manager.step(cur - 1, 1) - attn_mask = decode_mask - q_seq_lens = decode_q_seq_lens + while cur <= max_seq_lens and it < max_new_tokens: + batch_valid_length = Tensor([cur for _ in range(batch_size)], dtype=dtype.int32) + if is_prefill: + inp = Tensor([input_ids[i][:cur] for i in range(batch_size)], dtype=dtype.int32) + pos = mint.arange(cur).astype(dtype.int32) + block_tables, slot_mapping = cache_manager.step(0, cur) + attn_mask = ops.logical_not(ops.sequence_mask(pos + 1, cur)).astype(config.param_dtype) + q_seq_lens = None + else: + inp = Tensor([[input_ids[i][cur - 1]] for i in range(batch_size)], dtype=dtype.int32) + pos = Tensor([[cur - 1] for _ in range(batch_size)], dtype=dtype.int32).view(-1) + block_tables, slot_mapping = cache_manager.step(cur - 1, 1) + attn_mask = decode_mask + q_seq_lens = decode_q_seq_lens - model_input = Qwen2ModelInput( - input_ids=inp, - positions=pos, - batch_valid_length=batch_valid_length, - is_prefill=is_prefill, - attn_mask=attn_mask, - k_caches=cache_manager.k_caches, - v_caches=cache_manager.v_caches, - block_tables=block_tables, - slot_mapping=slot_mapping, - q_seq_lens=q_seq_lens - ) + model_input = Qwen2ModelInput( + input_ids=inp, + positions=pos, + batch_valid_length=batch_valid_length, + is_prefill=is_prefill, + attn_mask=attn_mask, + k_caches=cache_manager.k_caches, + v_caches=cache_manager.v_caches, + block_tables=block_tables, + slot_mapping=slot_mapping, + q_seq_lens=q_seq_lens + ) - logits = model(model_input) + logits = model(model_input) - next_tokens = sample(logits) + next_tokens = sample(logits) - for i in range(batch_size): - if cur >= len(input_ids[i]): - input_ids[i].append(int(next_tokens[i])) + for i in range(batch_size): + if cur >= len(input_ids[i]): + input_ids[i].append(int(next_tokens[i])) - cur += 1 - it += 1 - if is_prefill: - is_prefill = False + cur += 1 + it += 1 + if is_prefill: + is_prefill = False - for i in range(batch_size): - if eos_token_id in input_ids[i]: - eos_idx = input_ids[i].index(eos_token_id) - input_ids[i] = input_ids[i][: eos_idx + 1] + for i in range(batch_size): + if eos_token_id in input_ids[i]: + eos_idx = input_ids[i].index(eos_token_id) + input_ids[i] = input_ids[i][: eos_idx + 1] - return input_ids - ``` + return input_ids - 上面的generate函数模拟了大语言模型推理的迭代过程,其中核心步骤包括以下几个: + 上面的generate函数模拟了大语言模型推理的迭代过程,其中核心步骤包括以下几个: - 1. **模型输入准备**:准备模型推理需要的输入数据,构造Qwen2ModelInput对象,其主要的参数包括: + 1. **模型输入准备**:准备模型推理需要的输入数据,构造Qwen2ModelInput对象,其主要的参数包括: - **input_ids**:输入的词表id的list,每个batch一个list表示。 + **input_ids**:输入的词表id的list,每个batch一个list表示。 - **positions**:表示输入的词表在推理语句中的位置信息,主要用于rope旋转位置编码。 + **positions**:表示输入的词表在推理语句中的位置信息,主要用于rope旋转位置编码。 - **batch_valid_length**:表示当前推理的语句长度,主要是用于获取KVCache的KV值。通常是positions的值加1,投机推理场景下可能大于positions的值加1。 + **batch_valid_length**:表示当前推理的语句长度,主要是用于获取KVCache的KV值。通常是positions的值加1,投机推理场景下可能大于positions的值加1。 - **is_prefill**:是否是全量推理。全量推理需要计算多个KV值;增量推理通常可以复用上一轮计算的KV结果,只需要计算最后一个KV值。 + **is_prefill**:是否是全量推理。全量推理需要计算多个KV值;增量推理通常可以复用上一轮计算的KV结果,只需要计算最后一个KV值。 - **attn_mask**:用于注意力分数计算时隐藏掉不必要的信息,通常是一个上三角或者下三角的标准矩阵(有效值是1,其余是0)。 + **attn_mask**:用于注意力分数计算时隐藏掉不必要的信息,通常是一个上三角或者下三角的标准矩阵(有效值是1,其余是0)。 - **kv_caches**:KVCache对象,保存了所有计算的KV结果。 + **kv_caches**:KVCache对象,保存了所有计算的KV结果。 - **block_tables&slot_mapping**:表示当前推理词表使用的KVCache具体信息,block_tables表示每个batch当前使用的block,slot_mapping表示对应的单词在block中的具体位置。如block_tables=【2, 10】,slot_mapping=【1200】,block_size=128,表示当前推理使用了第2个和第10个block,当前单词用了第1200个block单元,即第10个block的第48个单元的KV值。 + **block_tables&slot_mapping**:表示当前推理词表使用的KVCache具体信息,block_tables表示每个batch当前使用的block,slot_mapping表示对应的单词在block中的具体位置。如block_tables=[2, 10],slot_mapping=[1200],block_size=128,表示当前推理使用了第2个和第10个block,当前单词用了第1200个block单元,即第10个block的第48个单元的KV值。 - **q_seq_lens**:表示注意力中query的长度,主要是PagedAttention算子使用。标准模型下值一般是1,投机推理场景下可能大于1。 + **q_seq_lens**:表示注意力中query的长度,主要是PagedAttention算子使用。标准模型下值一般是1,投机推理场景下可能大于1。 - 2. **模型计算**:调用主干模型网络启动模型计算逻辑,计算出下一个单词的概率分布。 + 2. **模型计算**:调用主干模型网络启动模型计算逻辑,计算出下一个单词的概率分布。 - 3. **采样结果**:通过sample采样计算获取下一个单词的id(此处使用argmax,即选择概率最大的单词)。 + 3. **采样结果**:通过sample采样计算获取下一个单词的id(此处使用argmax,即选择概率最大的单词)。 - 4. **更新下一个迭代输入**:更新下个迭代的词表list,进入下一个迭代。 + 4. **更新下一个迭代输入**:更新下个迭代的词表list,进入下一个迭代。 - 完成以上的迭代后,可以做一些优化,由于此处模型推理实现按照推理单词个数来结束,推理结果可能会被突然打断,因此可以通过tokenizer的断句词表id,将结果圈定在最后断句(如句号)的位置,提升文本结果的可读性。封装完成后,可以通过以下代码简单的调用单词生成过程: + 完成以上的迭代后,可以做一些优化,由于此处模型推理实现按照推理单词个数来结束,推理结果可能会被突然打断,因此可以通过tokenizer的断句词表id,将结果圈定在最后断句(如句号)的位置,提升文本结果的可读性。封装完成后,可以通过以下代码简单的调用单词生成过程: - ```python - output = generate( - model=model, - cache_manager=cache_manager, - input_ids=input_ids, - max_new_tokens=max_new_tokens, - eos_token_id=tokenizer.eos_token_id, - max_seq_lens=max_seq_lens - ) - ``` + .. code:: python + + output = generate( + model=model, + cache_manager=cache_manager, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + eos_token_id=tokenizer.eos_token_id, + max_seq_lens=max_seq_lens + ) - **后处理**:根据网络推理的输出,利用tokenizer的反向能力,将token id的list转换成一句可理解的语句。 - ```python - result = [tokenizer.decode(a) for a in output] - print(result) - ``` + .. code:: python + + result = [tokenizer.decode(a) for a in output] + print(result) + + 执行此Python代码,会打印如下输出: - 执行此Python代码,会打印如下输出: + .. code:: shell - ```shell - I love Beijing, because it is a city that is constantly changing. I have been living here for 10 years and I have seen the city changes so much. ... - ``` + I love Beijing, because it is a city that is constantly changing. I have been living here for 10 years and I have seen the city changes so much. ... - 可以看到,将模型推理的token id翻译后,即是一句可以被正常人理解的语句,实际验证过程中,由于do_sample的随机性,每次推理会有一定的差异,但是结果的逻辑基本都是可以被理解的。 + 可以看到,将模型推理的token id翻译后,即是一句可以被正常人理解的语句,实际验证过程中,由于do_sample的随机性,每次推理会有一定的差异,但是结果的逻辑基本都是可以被理解的。 - 完整端到端样例可以参考[infer.py](https://gitee.com/mindspore/docs/blob/master/docs/sample_code/infer_code/qwen2/infer.py) + 完整端到端样例可以参考 `infer.py `_ 。 -### 模型并行 +模型并行 +~~~~~~~~ 对于模型参数比较多的大语言模型,如Llama2-70B、Qwen2-72B,由于其参数规模通常会超过一张GPU或者NPU的内存容量,因此需要采用多卡并行推理。MindSpore大语言模型推理支持将原始大语言模型切分成N份可并行的子模型,使其能够分别在多卡上并行执行,在实现超大模型推理同时,也利用多卡中更多的资源提升性能。MindFormers模型套件提供的模型脚本天然支持将模型切分成多卡模型执行。 @@ -379,13 +396,16 @@ cache_manager = CacheManager(config, block_num, block_size, batch_size) 为了更加清晰的描述模型并行计算的流程,本章基于最基础和最普遍的模型并行策略进行说明,用户可以通过以下几步来实现模型的并行适配: -1. **模型适配**:MindSpore大语言模型多卡运行时,通常使用模型并行,因此原始模型需要根据卡数进行切分,如[1024,4096]和[4096, 2048]矩阵乘法,可以切分成2个[1024,4096]和[4096, 1024]的矩阵乘法。而不同的切分可能带来不同的并行计算性能。对于Qwen、LLAMA这类大语言模型而言,其切分主要包含在Attention中query、key、value这些数据的linear操作上。 +1. **模型适配**:MindSpore大语言模型多卡运行时,通常使用模型并行,因此原始模型需要根据卡数进行切分,如[1024, 4096]和[4096, + 2048]矩阵乘法,可以切分成2个[1024, 4096]和[4096, + 1024]的矩阵乘法。而不同的切分可能带来不同的并行计算性能。对于Qwen、LLAMA这类大语言模型而言,其切分主要包含在Attention中query、key、value这些数据的linear操作上。 2. **权重适配**:除了模型结构的并行化改造外,由于模型计算中的权重也被切分了,因此在模型加载的时候,相关的权重也要进行切分,以尽量减少不必要权重加载占用显存。对于大语言模型而言,主要的权重都集中在embbeding和linear两个网络层中,因此权重加载的适配主要涉及这两个模块改造。 -3. **模型推理**:和单卡推理不同,多卡推理需要同时启动多个进程来并行进行推理,因此在启动模型推理时,相比于直接运行脚本,多卡推理需要一次运行多组相关进程。MindSpore框架为用户提供了msrun的并行运行工具,具体使用方法可以参考[构建可并行的大语言模型网络](./ms_infer_parallel_infer.md)。 +3. **模型推理**:和单卡推理不同,多卡推理需要同时启动多个进程来并行进行推理,因此在启动模型推理时,相比于直接运行脚本,多卡推理需要一次运行多组相关进程。MindSpore框架为用户提供了msrun的并行运行工具,具体使用方法可以参考 `构建可并行的大语言模型网络 <./ms_infer_parallel_infer.md>`_。 -### 模型量化 +模型量化 +~~~~~~~~ MindSpore大语言模型支持以下量化技术,来提升模型推理性能: @@ -401,14 +421,17 @@ MindSpore大语言模型支持以下量化技术,来提升模型推理性能 2. **模型推理**:加载标准模型,将模型网络进行量化改造(插入相应量化算子),加载量化后的权重,调用模型推理。 -具体模型量化的详细资料可以参考[模型量化](./quantization.md)。 +具体模型量化的详细资料可以参考 `模型量化 <./ms_infer_quantization>`_。 -## 高级用法 +高级用法 +-------- - **使用自定义算子优化模型推理** - MindSpore大语言模型推理支持用户自定义算子接入,以实现用户特定场景的算子优化,或者实现网络中的算子融合,用户可以通过简单的修改网络脚本的算子API来实现自定义算子的使能与关闭,具体可以参考[自定义算子](../../custom_program/operation/op_custom_ascendc.md)。 + MindSpore大语言模型推理支持用户自定义算子接入,以实现用户特定场景的算子优化,或者实现网络中的算子融合,用户可以通过简单的修改网络脚本的算子API来实现自定义算子的使能与关闭,具体可以参考 `自定义算子 <../../custom_program/operation/op_custom_ascendc.md>`_。 - **大语言模型离线推理** - 由于大语言模型体积巨大,因此MindSpore大语言模型推理推荐用户使用更灵活的在线推理(权重CKPT+网络脚本),但是在一些特定场景,如端侧或者边缘侧大模型,由于运行环境受限,不一定有Python或者MindSpore包的环境下,用户可以使用MindSpore Lite离线推理方案。此时,用户需要将模型导出成MindSpore的统一模型表达MindIR文件,并将其传给MindSpore Lite运行时,具体教程可以参考[Lite推理概述](../lite_infer/overview.md)。 + 由于大语言模型体积巨大,因此MindSpore大语言模型推理推荐用户使用更灵活的在线推理(权重CKPT+网络脚本),但是在一些特定场景,如端侧或者边缘侧大模型,由于运行环境受限,不一定有Python或者MindSpore包的环境下,用户可以使用MindSpore + Lite离线推理方案。此时,用户需要将模型导出成MindSpore的统一模型表达MindIR文件,并将其传给MindSpore + Lite运行时,具体教程可以参考 `Lite推理概述 <../lite_infer/overview.md>`_。 diff --git a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_network_develop.md b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_network_develop.md index ddc2904fb9a8d6cda21297559ed6745a83553285..4c9197ba886aefb73cd3b013de3e43fcbe531179 100644 --- a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_network_develop.md +++ b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_network_develop.md @@ -26,7 +26,7 @@ MindSpore推荐用户先用动态图模式进行模型开发,然后根据需 - **RmsNorm&Linear**:输出线性归一层,在Transformer结构计算完后,将结果归一成和模型词表一样的维度,最终输出成每个token的概率分布返回。 -使用MindSpore大语言模型推理构建网络,可以根据MindSpore提供的算子自己拼装。下面以Qwen2模型为例,简单描述构建模型的过程,完整端到端样例可以参考[qwen2.py](https://gitee.com/mindspore/docs/blob/master/docs/sample_code/infer_code/qwen2/qwen2.py) +使用MindSpore大语言模型推理构建网络,可以根据MindSpore提供的算子自己拼装。下面以Qwen2模型为例,简单描述构建模型的过程,完整端到端样例可以参考[qwen2.py](https://gitee.com/mindspore/docs/blob/master/docs/sample_code/infer_code/qwen2/qwen2.py)。 ### 基础公共网络层 @@ -706,9 +706,9 @@ class Qwen2Model(nn.Cell): return hidden_state ``` -通过在nn.Cell的construct方法加上ms.jit装饰器,这个Cell的计算就会转化为静态图执行,其中参数意义如下: +通过在nn.Cell的construct方法加上mindspore.jit装饰器,这个Cell的计算就会转化为静态图执行,其中参数意义如下: -- **jit_level**:编译级别,当前MindSpore推理主要支持O0级别,O1级别(会有一些算子融合优化),暂不支持O2级别(整图下沉)。 +- **jit_level**:编译级别,当前MindSpore推理主要支持O0级别、O1级别(会有一些算子融合优化)。 - **infer_boost**:开启推理加速优化,开启后,运行时会做一些调度优化和流优化,提升推理性能。