diff --git a/tutorials/source_en/model_infer/ms_infer/images/embedding2.png b/tutorials/source_en/model_infer/ms_infer/images/embedding2.png new file mode 100644 index 0000000000000000000000000000000000000000..aa172f4311f5dd8eb64a51f7501258649d1b9717 Binary files /dev/null and b/tutorials/source_en/model_infer/ms_infer/images/embedding2.png differ diff --git a/tutorials/source_en/model_infer/ms_infer/ms_infer_model_infer.rst b/tutorials/source_en/model_infer/ms_infer/ms_infer_model_infer.rst new file mode 100644 index 0000000000000000000000000000000000000000..bb0f7e38890642b609a8786a6e5a1188ed4be7c8 --- /dev/null +++ b/tutorials/source_en/model_infer/ms_infer/ms_infer_model_infer.rst @@ -0,0 +1,437 @@ +MindSpore LLM Inference with Framework +========================================== + +.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg + :target: https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/parallel/optimize_technique.rst + :alt: View Source On Gitee + +.. toctree:: + :maxdepth: 1 + :hidden: + + ms_infer_network_develop + ms_infer_parallel_infer + ms_infer_quantization + ms_infer_model_serving_infer + +Background +------------ + +At the end of 2022, with the release of OpenAI's ChatGPT, a new research direction emerged in the AI domain, that is, LLMs based on the Transformers structure. These LLMs exhibited capabilities beyond expectations and achieved impressive results in various tests, quickly becoming the research focus of AI. + +One significant research direction in LLMs is improving their cost-effectiveness in practical applications. + +- An LLM usually has tens of billions of parameters. In this case, the computation workload for a single model inference process is extremely high and requires massive compute resources. As a result, AI service providers find that the cost of an LLM inference is very high and cannot be effectively applied to real-world scenarios. + +- To address the high costs of LLM inference, the MindSpore framework offers inference capabilities. Based on the characteristics of mainstream LLMs, MindSpore has deeply optimized the LLM deployment and inference processes, achieving optimal cost efficiency in model inference. + +Model Principles +------------------------ + +Before learning about the inference capability of MindSpore, first explore how current mainstream LLMs achieve such amazing intelligence. We will take the most common text generation models as examples to briefly introduce the inference principles of LLMs, and see how AI models perform complex tasks such as conversation and summarizing main ideas through computation. + +Similar to a common model, the construction of an LLM consists of two phases: training and inference. + +- **Training**: The training process of an LLM can be simply understood as that a model continuously reading and learning from massive text data. During this process, the model records the position relationship and occurrence frequency of each text element in the model weight. For example, there is a high probability that "9.6 million square kilometers" will appear after the sentence "China has an area of". During the training process, the LLM records that the two sentences are strongly associated through massive data input. + +- **Inference**: The LLM inference process is to find the most relevant subsequent text elements from the training database based on a specific piece of text provided. 
For example, if you ask "China has an area of", the LLM can return "9.6 million square kilometers" based on the information recorded during training, providing you with your desired answer. + +In actual text processing scenarios, languages are complex and changing. Therefore, it is difficult to identify the direct correlation between two sentences. LLM technologies usually use the tokenization method, that is, breaking down "China has an area of" into multiple common words such as "China", "has", "an", "area", and "of". This method can better cope with the impact of text differences. For example, the similarity between the phrases "the area of China is" and "China has an area of" is nearly 0, while the similarity between ["the", "area", "of", "China", "is"] and ["China", "has", "an", "area", "of"] can be considered as 60%, which can effectively helps the LLM identify such text differences. This technique, known as tokenization, breaks a piece of text into a combination of tokens (usually words and punctuation). The process of generating a sentence is as follows: The LLM infers the next token based on the current token combination, combines the next token with the previous tokens to form a new input, and gradually completes the generation of the entire text through repeated training step. The following table briefly describes an example of LLM inference. + +Input: Capital of China + +.. list-table:: Inference example + :header-rows: 1 + + * - Inference iteration + - Inference input + - Input vector + - Inference result + * - 1 + - China's capital + - [China, 's, capital] + - Beijing + * - 2 + - China's capital, Beijing + - [China, 's, capital, Beijing] + - is + * - 3 + - China's capital, Beijing, is + - [China, 's, capital, Beijing, is] + - Beautiful + * - 4 + - China's capital, Beijing, is beautiful. + - [China, 's, capital, Beijing, is, beautiful] + - END + +In each step of training, the LLM infers the next token based on the current context and combines the token with the previous statement to form the input of the next step of training. After multiple steps of training, if the special token "END" is generated, the model considers that the inference ends, and returns the result. + +Procedure +---------------- + +MindSpore LLM inference provides you with an "out-of-the-box" deployment and inference capability. You can use the LLM APIs provided by MindSpore to quickly deploy your own LLMs and optimize them based on model features, achieving the optimal cost-effectiveness and bringing LLM capabilities to practical applications. The following figure shows the key steps of model inference using the MindSpore LLM inference feature. + +.. figure:: ./images/llm_infer_flow.png + :alt: llm-infer-flow + +1. **Weight preparation**: The weight data is the intelligent core of an LLM, and therefore the first step of deploying a model is to obtain and prepare the corresponding weight files. +2. **Model loading**: During inference, the model structure may differ based on the optimization techniques used. Therefore, the backbone network of the model needs to be constructed based on the model network structure to facilitate subsequent inference. +3. **Status determination**: Based on the specific semantics of the inference request, the model determines whether to continue with inference. This process is mainly used to determine whether to end multi-step inference. If inference ends (for example, after answering a question), the results are returned; otherwise, the next step of inference continues. +4. 
**Inference preprocessing**: The inference data is preprocessed according to the inference request. Common preprocessing steps include using a tokenizer to convert the statement into a group of digital vectors represented by indexes, allowing the LLM to accurately recognize the task content, and constructing some special input of model inference for acceleration (for example, cache information of incremental inference of KVCache). +5. **Model inference**: The model performs inference based on the input data, typically returning the probability distribution of the next token in the sentence. +6. **Inference postprocessing**: Based on the results of the model inference, the next token is computed and converted back into text. If inference does not end, the token is assembled into the input for the next step of inference to continue the process. + +Main Features +---------------- + +To achieve the optimal cost-effectiveness, MindSpore LLM has undergone multiple in-depth optimizations tailored to the characteristics of LLM networks. The main features include: + +- **Full and incremental inference**: The core network structure of LLMs primarily utilizes a transformer-based self-attention mechanism, where attention scores of all tokens are computed in each training step. However, the attention scores of the same token sequence yield the same key and value (KV) results. For example, the KV of ["the", "area", "of", "China", "is"] may be understood as a combination of ["the", "area", "of", "China"] and ["is"]. Therefore, by caching the keys and values of previously computed sequences, the computation workload for the next training step can be reduced. This technique is commonly known as KVCache optimization. In two consecutive training steps, *N* and *N* +1, the KVs from training step *N* can be fully reused in training step *N* +1 because the first *N* sequences are identical and only the first token of *N* +1 steps needs to be computed. In this way, the model inference can be divided into the following two phases: + + - **Full inference**: This is the first training step initiated by your input, where the length *N* of the input statement and the content is unpredictable. All keys and values must be computed, which is called a full inference. + + - **Incremental inference**: After completing the first training step, the keys and values from the previous statement are stored in the KVCache. In this case, only the KV corresponding to the latest token need to be computed, which are then combined with the cached result to compute the attention score, constituting an incremental inference. + +- **Attention optimization**: The primary computation in the LLM's network involves the computation of attention. Since the attention size in mainstream models is often large (typically 4096 x 4096 or more), the performance of the entire inference process heavily relies on the efficiency of attention computation. Many studies focus on optimizing the performance of attention computation, with notable techniques such as flash attention and page attention. + + - **Flash attention**: During attention computation, two large matrices (4096 x 4096) are multiplied. This computation breaks the large matrix into smaller matrices that can be processed on multiple chips. Subject to the minimum cache size of chips, data must continuously be moved between the cache and main memory. As a result, compute resources cannot be fully used. Consequently, attention computation is often bandwidth-bound. 
Flash attention addresses this by dividing attention into blocks, allowing each block to be computed independently on a chip, avoiding multiple data movements during the computation of KVs and enhancing attention computation performance. For details, see `Flash Attention `_. + + - **Page attention graphics memory optimization**: Standard flash attention reads and saves the entire input KV data each time. This method is simple but wastes many resources. For example, "China's capital" and "China's national flag" share "China's", leading to identical KVs for their attention. Standard flash attention needs to store two copies of KVs, wasting the graphics memory. Page attention optimizes KVCache based on the page table principle of the Linux OS. It stores KVs in blocks of a specific size. In the preceding example, "China", "'s", "capital", and "national flag" are stored as four pieces of KV data. Compared with the original six pieces of data, this method effectively saves graphics memory resources. In the service-oriented scenario, more idle graphics memory allows for a larger batch size for model inference, thereby achieving higher throughput. For details, see `Page Attention `_. + +- **Model quantization**: MindSpore LLM inference supports quantization to reduce the model size. It provides technologies such as A16W8, A16W4, A8W8, and KVCache quantizations to reduce model resource usage and improve the inference throughput. + +Inference Tutorial +------------------------ + +Based on the mainstream Qwen2 open-source LLM, this section demonstrates how to use the inference capability of the MindSpore model to build an example of end-to-end text generation. + +.. note:: + + The Qwen2 model has multiple versions and configurations. This document uses Qwen2-7B-Instruct as an example. + +Environment Preparations +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MindSpore LLM inference with the framework mainly depends on the MindSpore open-source software. Before using the framework, you need to install the MindSpore Python package. You are advised to use the conda virtual environment. You can run the following commands for installation: + +.. code:: shell + + export PYTHON_ENV_NAME=mindspore-infer-py311 + conda create -n ${PYTHON_ENV_NAME} python=3.11 + conda activate ${PYTHON_ENV_NAME} + pip install mindspore + +You can also install the Python package adapted to your environment by referring to the official installation document. For details, see `MindSpore Installation `_. + +MindSpore inference mainly runs on the Ascend AI Processor environment. You need to install the corresponding Ascend development environment. For details, see the following: + +.. code:: shell + + pip install ${ASCEND_HOME}/lib64/te-*.whl + pip install ${ASCEND_HOME}/lib64/hccl-*.whl + pip install sympy + +If you need to reuse the tokenizer capability of the mainstream LLM, you can install the Transformers software package. + +.. code:: shell + + pip install transformers + +If you need to use model quantization to enhance inference performance, you need to install the mindspore_gs package. For details, see `Installing MindSpore Golden Stick `_. + +Weight Preparation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Obtain the weight file of the LLM for weight preparation. In addition, each LLM usually has its own token list, which indicates a full set of words supported by the model. Therefore, you need to obtain the tokenizer mapping in addition to the model weight. MindSpore supports the direct loading of the safetensor weight file. 
You can directly download the model weight file from the Hugging Face official website. + +For the Qwen2 LLM, you are advised to use the pre-trained weight files and tokenizer mapping provided on the Hugging Face official website. You can run the following commands to download weights: + +.. code:: shell + + git lfs install + git clone https://huggingface.co/Qwen/Qwen2-7B-Instruct + +After the download is complete, the following file tree structure should be displayed in the related directory: + +.. code:: shell + + ls + |- config.json + |- LICENSE + |- merges.txt + |- model-00001-of-00004.safetensors + |- model-00002-of-00004.safetensors + |- model-00003-of-00004.safetensors + |- model-00004-of-00004.safetensors + |- model.safetensors.index.json + |- README.md + |- tokenizer_config.json + |- tokenizer.json + |- vocab.json + +Model Building +~~~~~~~~~~~~~~~~~~~~ + +You need to build a model and load the weight by running the following codes first: + +.. code:: python + + import os + import mindspore as ms + from qwen2 import Qwen2Config, Qwen2ForCausalLM, CacheManager + from mindspore import Tensor, mint + + # set mindspore context and envs + os.environ["MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST"] = "PagedAttention" + + ms.set_context(infer_boost="on") + ms.set_context(mode=ms.context.PYNATIVE_MODE) + + model_path = "/path/to/model" + input_str = ["I love Beijing, because", "Hello, Qwen2"] + batch_size = len(input_str) + max_new_tokens = 64 + block_size = 128 + max_seq_lens = block_size * 10 + block_num = (max_seq_lens * batch_size) // block_size + + config = Qwen2Config.from_json(model_path + "/config.json") + + model = Qwen2ForCausalLM(config) + # load weight + model.load_weight(model_path) + + cache_manager = CacheManager(config, block_num, block_size, batch_size) + +Qwen2 is the network script (qwen2.py) of the model, which must be in the same directory as the current script. For details, see `Building an LLM Inference Network from Scratch <./ms_infer_network_develop.md>`_. You can also use other network scripts, but you need to modify the corresponding model APIs. + +The first step in the script is to set MindSpore environment variables, including: + +- **MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST**: sets the TH flattening operator supported by MindSpore for PagedAttention. MindSpore only supports the TH format in dynamic graph mode. Therefore, if you want to develop in dynamic graph mode, you need to set this environment variable. You can also use the BSH format. + +- **infer_boost**: enables inference optimization. This optimization is mainly to enable MindSpore fusion operators such as FlashAttention and PagedAttention. + +- **mode**: sets the execution mode to dynamic graph mode. This mode is more convenient for debugging and development. You are advised to use this mode during model development. + +The second step in the script is to initialize the model and KVCache using the class provided by the model script **qwen2.py**. The following parameters are included: + +- **input_str**: specifies the original text to be inferred. A string list with **batch_size** set to **2** is passed at a time, indicating that two statements are inferred at the same time. + +- **model_path**: specifies the model directory path, that is, the path of the model downloaded from the Hugging Face official website. + +- **max_new_tokens**: specifies the maximum number of inference words. When the number of inference words reaches the maximum, the inference stops and is used in subsequent iterations. 
+ +- **block_size**: specifies the block size of the KVCache object managed by PagedAttention. A smaller value of **block_size** indicates finer division and higher reuse probability of different requests. A larger value of **block_size** indicates that more valid data is read at a time during network computing, and the computing performance is better. + +- **max_seq_len**: specifies the maximum length supported by model inference. This parameter can be obtained from **config** and affects the graphics memory usage of KVCache. The Qwen2 configuration is large (32,000) by default. Therefore, this parameter is set to 10 times the value of **block_size** for simplification. + +Initialize the model based on the preceding parameters to obtain the model and cache_manager objects. + +Model Inference +~~~~~~~~~~~~~~~~~~~~ + +Once the model is built, you can utilize the model object for text generation, enabling applications such as self-service customer support, intelligent Q&A, and chatbots. However, the input of an application is usually a language text, which cannot be directly used as the input of the model for computation. Therefore, we need to add the preprocessing and postprocessing logic to convert the text language into token data that can be identified by the model. After the inference computation is complete, the token data is converted into the text language. The following uses a simple Q&A text generation as an example to describe the process. + +- **Preprocessing**: Use the tokenizer's data to break a sentence down into a list represented by multiple token IDs. In this case, the tokenizer of the open-source community Transformers is used. + + .. code:: python + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + input_str = ["I love Beijing, because", "Hello, Qwen2"] + + input_ids = tokenizer(input_str)["input_ids"] + + print(input_ids) + + After the Python code is executed, the following information is displayed: + + .. code:: shell + + [[40, 2948, 26549, 11, 1576], [9707, 11, 1207, 16948, 17]] + + [40, 2948, 26549, 11, 1576] corresponds to the word sequence "I love Beijing, because". **40** indicates the token corresponding to "I", **2948** indicates the token corresponding to "love", **26549** indicates the token corresponding to "Beijing", **11** indicates the token corresponding to ", " (comma and space), and **1576** indicates the token corresponding to "because". This format can be directly passed to the model for inference. Similarly, [9707, 11, 1207, 16948, 17] corresponds to the input sequence "Hello, Qwen2". In this example, two requests are passed at a time for batch calculation. + +- **Entire network computing**: The data and configuration of the current input token are specified so that the model object can iteratively infer the token result of each step through multiple rounds of computation. To simplify the code, you can encapsulate the iterative inference into the following generate function: + + .. 
code:: python + + from typing import List + from mindspore import ops, mint, Tensor, dtype + from qwen2 import Qwen2Config, Qwen2ModelInput, Qwen2ForCausalLM, CacheManager, sample + + def generate(model: Qwen2ForCausalLM, cache_manager: CacheManager, input_ids: List, max_new_tokens: int, max_seq_lens: int, eos_token_id: int): + batch_size = len(input_ids) + assert max_seq_lens >= max(map(len, input_ids)) + + cur = min(map(len, input_ids)) + is_prefill = True + it = 0 + + decode_q_seq_lens = Tensor([1 for _ in range(batch_size)], dtype=dtype.int32) + decode_mask = ops.zeros((1, 1), dtype=config.param_dtype) + attn_mask = None + q_seq_lens = None + + while cur <= max_seq_lens and it < max_new_tokens: + batch_valid_length = Tensor([cur for _ in range(batch_size)], dtype=dtype.int32) + if is_prefill: + inp = Tensor([input_ids[i][:cur] for i in range(batch_size)], dtype=dtype.int32) + pos = mint.arange(cur).astype(dtype.int32) + block_tables, slot_mapping = cache_manager.step(0, cur) + attn_mask = ops.logical_not(ops.sequence_mask(pos + 1, cur)).astype(config.param_dtype) + q_seq_lens = None + else: + inp = Tensor([[input_ids[i][cur - 1]] for i in range(batch_size)], dtype=dtype.int32) + pos = Tensor([[cur - 1] for _ in range(batch_size)], dtype=dtype.int32).view(-1) + block_tables, slot_mapping = cache_manager.step(cur - 1, 1) + attn_mask = decode_mask + q_seq_lens = decode_q_seq_lens + + model_input = Qwen2ModelInput( + input_ids=inp, + positions=pos, + batch_valid_length=batch_valid_length, + is_prefill=is_prefill, + attn_mask=attn_mask, + k_caches=cache_manager.k_caches, + v_caches=cache_manager.v_caches, + block_tables=block_tables, + slot_mapping=slot_mapping, + q_seq_lens=q_seq_lens + ) + + logits = model(model_input) + + next_tokens = sample(logits) + + for i in range(batch_size): + if cur >= len(input_ids[i]): + input_ids[i].append(int(next_tokens[i])) + + cur += 1 + it += 1 + if is_prefill: + is_prefill = False + + for i in range(batch_size): + if eos_token_id in input_ids[i]: + eos_idx = input_ids[i].index(eos_token_id) + input_ids[i] = input_ids[i][: eos_idx + 1] + + return input_ids + + The generate function simulates the iteration process of LLM inference. The core steps are as follows: + + 1. **Model input preparation**: Prepare the input data required for model inference and construct the Qwen2ModelInput object. The main parameters are as follows: + + **input_ids**: specifies the list of input vocabulary IDs. Each batch is represented by a list. + + **positions**: specifies position information of the input vocabulary in the inference statement, which is mainly used for RoPE. + + **batch_valid_length**: specifies the length of the current inference statement, which is used to obtain the KV of KVCache. Generally, the value is the value of **positions** plus 1. In speculative inference scenarios, the value may be greater than the value of **positions** plus 1. + + **is_prefill**: specifies whether full inference is performed. Full inference needs to compute multiple KVs. Incremental inference can reuse the KV results computed in the previous computation, and only the last KV needs to be computed. + + **attn_mask**: hides unnecessary information during attention score computation. It is usually a standard matrix with an upper or lower triangle (valid elements are marked with **1** and others are **0**). + + **kv_caches**: specifies the KVCache object, which stores all computed KV results. 
+ + **block_tables&slot_mapping**: specifies the KVCache information used by the current inference vocabulary. **block_tables** indicates the block used by each batch, and **slot_mapping** indicates the position of the corresponding word in the block. For example, if **block_tables** is **[2, 10]**, **slot_mapping** is **[1200]**, and **block_size** is **128**, the second and tenth blocks are used for inference, and the 1200th block unit is used for the current word, that is, the KV of the 48th unit in the tenth block. + + **q_seq_lens**: specifies the length of the query in attention, which is mainly used by the PagedAttention operator. The value is **1** in the standard model, and may be greater than 1 in speculative inference scenarios. + + 2. **Model calculation**: Call the main model network to start the model computation logic and compute the probability distribution of the next word. + + 3. **Sampling result**: Obtain the ID of the next word through sampling computing (**argmax** is used as an example, that is, the word with the highest probability is selected). + + 4. **Input update of the next iteration**: Update the word list of the next iteration and enter the next iteration. + + After the iteration is complete, you can optimize the model. The model inference ends based on the number of inference words. The inference result may be suddenly interrupted. Therefore, you can use the tokenizer's sentence segmentation table ID to enclose the result at the position of the last sentence segmentation (for example, period) to enhance the readability of the text result. After the encapsulation is complete, you can call the word generation process using the following code: + + .. code:: python + + output = generate( + model=model, + cache_manager=cache_manager, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + eos_token_id=tokenizer.eos_token_id, + max_seq_lens=max_seq_lens + ) + +- **Postprocessing**: Based on the network inference output, use the conversion capability of the tokenizer to convert the token ID list into a comprehensible statement. + + .. code:: python + + result = [tokenizer.decode(a) for a in output] + print(result) + + After the Python code is executed, the following information is displayed: + + .. code:: shell + + I love Beijing, because it is a city that is constantly changing. I have been living here for 10 years and I have seen the city changes so much. ... + + It can be seen that the model-inferred token IDs are translated to a human-readable statement. In actual verification, due to the randomness of **do_sample**, each inference is different, but the result logic is basically understandable. + + For details about the complete end-to-end example, see `infer.py `_. + +Model Parallelism +~~~~~~~~~~~~~~~~~~~~ + +For LLMs with many model parameters, such as Llama2-70B and Qwen2-72B, the parameter scale usually exceeds the memory capacity of a GPU or NPU. Therefore, multi-device parallel inference is required. MindSpore LLM inference can shard the original LLM into *N* parallel models so that they can be executed on multiple devices in parallel. This not only enables inference for super LLMs but also enhances performance by leveraging more resources from the multiple devices. The model scripts provided by the MindFormers model suite can be used to shard a model into multi-device models for execution. 
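+
+As an illustration of this idea, the following sketch shards a single linear layer across devices in the way that the tensor parallelism described below does: each device holds only its slice of the weight (for example, a [4096, 2048] weight split into two [4096, 1024] shards), computes a partial result locally, and gathers the shards afterwards. This is a simplified example rather than the MindFormers implementation, and it assumes the communication group has already been initialized (for example, by launching the processes with the msrun tool) before the cell is built.
+
+.. code:: python
+
+    from mindspore import nn, ops, mint, Parameter, Tensor
+    from mindspore.communication import get_group_size
+
+    class ColumnParallelLinear(nn.Cell):
+        """Illustrative tensor-parallel linear: each rank holds 1/N of the output columns."""
+        def __init__(self, input_size: int, output_size: int, param_dtype) -> None:
+            super().__init__()
+            tp_size = get_group_size()
+            self.output_size_per_rank = output_size // tp_size
+            # only the local [output_size / N, input_size] weight shard is stored on this device;
+            # when loading the checkpoint, rank i copies its own rows of the full weight
+            self.weight = Parameter(
+                mint.zeros((self.output_size_per_rank, input_size), dtype=param_dtype),
+                requires_grad=False
+            )
+            self.matmul = ops.MatMul(transpose_b=True)
+            self.all_gather = ops.AllGather()  # concatenates the rank outputs along dim 0
+
+        def construct(self, x: Tensor) -> Tensor:
+            # local partial result: [tokens, output_size / N]
+            y_local = self.matmul(x, self.weight)
+            # gather the column shards from all ranks to recover [tokens, output_size]
+            return ops.transpose(self.all_gather(ops.transpose(y_local, (1, 0))), (1, 0))
+
+    # mindspore.communication.init() must be called once in each process (msrun prepares the
+    # required environment) before ColumnParallelLinear is constructed.
+
+In practice, the gather after every layer is usually avoided by pairing a column-parallel projection with a row-parallel one so that only a single AllReduce is needed per attention or MLP block, which is how the attention and linear layers mentioned in the adaptation steps below are typically adapted.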
+ +Currently, mainstream model parallel methods include the following: + +- **Data parallelism**: The data to be computed is divided into multiple parallel parts and computed on multiple devices in parallel. In the inference scenario, multiple statements can be computed in parallel through batch processing. Data parallelism can be understood as multiple model instances executed in parallel, and therefore no additional model adaptation is required. + +- **Tensor parallelism**: The operators to be computed by the model are sharded according to the network script definition. In the inference scenario, the number of shards is usually equal to the number of devices. The input and output of operator computation in the network change with the parallelism degree. Therefore, the model needs to be adapted to the parallelism. + +- **Pipeline parallelism**: The model is sharded into multiple instances based on the number of layers. Pipeline computation can be implemented between multiple requests. The network is sharded into multiple subnets. Therefore, the model needs to be adapted to the parallelism. + +- **Expert parallelism**: This is a parallel strategy specific to MoE LLMs. Different expert computations are distributed to different compute entities in parallel, and the computing performance is improved through concurrent expert control. + +To more clearly describe the model parallel computing process, this section describes the most basic and common model parallel policies. You can implement parallel adaptation of the model by performing the following steps: + +1. **Model adaptation**: When a MindSpore LLM is running on multiple devices, model parallelism is usually used. Therefore, the original model needs to be sharded based on the number of devices. For example, the matrix multiplication of [1024, 4096] and [4096, 2048] can be sharded into two matrix multiplications of [1024, 4096] and [4096, 1024], respectively. + Different sharding policies may bring different parallel computing performance. + For Qwen and Llama, the sharding mainly involves the linear operations on the query, key, and value data at the attention layer. + +2. **Weight adaptation**: In addition to the parallel reconstruction of the model structure, the weights in the model computation are also sharded. Therefore, the related weights need to be sharded during model loading to minimize the graphics memory occupied by unnecessary weight loading. For LLMs, the main weights are concentrated on the embedding and linear network layers. Therefore, the weight loading adaptation mainly involves the reconstruction of the two modules. + +3. **Model inference**: Unlike single-device inference, multi-device inference requires multiple processes to be started at the same time for parallel inference. Therefore, when starting model inference, multi-device inference requires running multiple groups of related processes at a time, instead of directly running scripts. The MindSpore framework provides the msrun parallel running tool. For details, see `Building a Parallel LLM Network <./ms_infer_parallel_infer.md>`_. + +Model Quantization +~~~~~~~~~~~~~~~~~~~~ + +The MindSpore LLM supports the following quantization technologies to improve the inference performance: + +- **A16W8/A16W4 quantization**: quantizes the weights of an LLM, saving float16 weights as 8-bit int8 or 4-bit int4 data. Before computation, the weights are de-quantized back to float16, reducing memory usage, enhancing model concurrency, and improving inference throughput. 
+ +- **A8W8 quantization**: quantizes the entire network of an LLM, converting float16 activations to 8-bit int8 data for computation. This doubles the computational efficiency of GPU or NPU computing units (for example, from 16 x 16 to 32 x 16). Specific quantization operators are required. This not only reduces memory usage but also significantly enhances computational performance. + +- **KVCache quantization**: reduces graphics memory consumption, effectively enhancing overall throughput. (KVCache consumes considerable graphics memory and model weights in LLM inference.) MindSpore supports quantizing KVCache from float16 to int8. Through flash attention and page attention, quantization and dequantization are fused into operators to reduce the overhead caused by quantization and improve the overall throughput. + +To quantize a model using golden-stick, perform the following steps: + +1. **Weight quantization**: Use a quantization algorithm to convert the model weight data from float16 to int8. + +2. **Model inference**: Load the standard model, quantize the model network (by inserting corresponding quantization operators), load the quantized weight, and call the model inference. + +For details about model quantization, see `Quantization <./ms_infer_quantization>`_. + +Advanced Usage +----------------- + +- **Using custom operators to optimize model inference** + + The MindSpore LLM inference supports the use of custom operators to optimize operators in specific scenarios or implement operator fusion on the network. Custom operators can be enabled or disabled by simply modifying the operator API in the network script. For details, see `Custom Operators <../../custom_program/operation/op_custom_ascendc.md>`_. + +- **Offline inference of LLMs** + + Given the substantial size of LLMs, you are advised to use more flexible online inference (weight CKPT and network script) for MindSpore LLM inference. However, in specific scenarios, such as running device or edge LLMs with limited running environments lacking Python or MindSpore packages, you can use the MindSpore Lite offline inference solution. + + In this case, you need to export the model to a MindIR file, which is the unified model expression of MindSpore, and send the file to the MindSpore Lite runtime. For details, see `Lite Inference Overview <../lite_infer/overview.md>`_. diff --git a/tutorials/source_en/model_infer/ms_infer/ms_infer_model_serving_infer.md b/tutorials/source_en/model_infer/ms_infer/ms_infer_model_serving_infer.md new file mode 100644 index 0000000000000000000000000000000000000000..7fef1b3ad25238ae825a00227cf79cdc144562d8 --- /dev/null +++ b/tutorials/source_en/model_infer/ms_infer/ms_infer_model_serving_infer.md @@ -0,0 +1,166 @@ + +# Service-oriented Model Inference + +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/model_infer/ms_infer/ms_infer_model_serving_infer.md) + +## Background + +MindSpore is an AI model development framework that provides efficient model development capabilities. 
Generally, the following code is used for model inference: + +```python + +input_str = "I love Beijing, because" + +model = Qwen2Model(config) +model.load_weight("/path/to/model") + +input_ids = tokenizer(input_str)["input_ids"] + +logits = model(input_ids) + +next_token = ops.argmax(logits) + +generate_text = tokenizer.decode(next_token) + +print(generate_text) +``` + +This model inference mode is simple, but the model and weight need to be reloaded each time inference is performed. As a result, the inference efficiency is low in actual applications. To solve this problem, a model inference backend service is usually deployed to receive inference requests online and send requests to the model for computing. This inference mode is called service-oriented inference. MindSpore does not provide the service-oriented inference capability. If service-oriented inference is required in actual applications, you need to develop a service backend and integrate the related model. + +To help users easily deploy out-of-the-box model inference capabilities in the production environment, MindSpore provides full-stack service-oriented model inference capabilities based on the popular vLLM model inference open-source software. Service-oriented inference supports real-time online inference and efficiently improves the overall throughput of model inference and reduces inference costs through efficient user request scheduling. + +## Main Features + +As an efficient service-oriented model inference backend, it should provide the following capabilities to maximize the deployment and running efficiency of models: + +- **Quick startup**: Quick loading and initialization of LLMs are implemented through technologies such as compilation cache and parallel loading, reducing the extra startup overhead caused by the continuous increase of model weights. + +- **Batch inference**: A proper batch grouping mechanism is used to implement optimal user experience in the case of massive concurrent requests. + +- **Efficient scheduling**: Full and incremental request scheduling is used to address full and incremental inference requirements of LLMs, maximizing resource computing efficiency and improving system throughput. + +## Inference Tutorial + +MindSpore inference works with the vLLM community solution to provide users with full-stack end-to-end inference service capabilities. The vLLM MindSpore adaptation layer implements seamless interconnection of the vLLM community service capabilities in the MindSpore framework. For details, see [vLLM MindSpore](https://www.mindspore.cn/vllm_mindspore/docs/en/master/index.html). + +This section describes the basic usage of vLLM MindSpore service-oriented inference. + +### Setting Up the Environment + +The vLLM MindSpore adaptation layer provides an environment installation script. 
You can run the following commands to create a vLLM MindSpore operating environment: + +```shell +# download vllm-mindspore code +git clone https://gitee.com/mindspore/vllm-mindspore.git +cd vllm-mindspore + +# create conda env +conda create -n vllm-mindspore-py311 python=3.11 +conda activate vllm-mindspore-py311 + +# install extra dependent packages +pip install setuptools_scm +pip install numba + +# run install dependences script +bash install_depend_pkgs.sh + +# install vllm-mindspore +python setup.py install +``` + +After the vLLM MindSpore operating environment is created, you need to install the following dependency packages: + +- **mindspore**: MindSpore development framework, which is the basis for model running. + +- **vLLM**: vLLM service software. + +- **vllm-mindspore**: vLLM extension that adapts to the MindSpore framework. It is required for running MindSpore models. + +- **msadapter**: adaptation layer for MindSpore to connect to PyTorch. Some vLLM functions depend on the PyTorch capabilities and need to be adapted by MSAdapter. + +- **golden-stick**: MindSpore model quantization framework. If the quantization capability is required, install this software. + +- **mindformers**: Transformer model library provided by the MindSpore framework. You can use the models directly or connect to the native models of MindSpore. + +### Preparing a Model + +The service-oriented vLLM MindSpore supports the direct running of the native Hugging Face model. Therefore, you can directly download the model from the Hugging Face official website. The following uses the Qwen2-7B-Instruct model as an example: + +```shell +git lfs install +git clone https://huggingface.co/Qwen/Qwen2-7B-Instruct +``` + +If `git lfs install` fails during the pull process, refer to the vLLM MindSpore FAQ for a solution. + +### Starting a Service + +Before starting the backend service, you need to set the environment variables based on the actual environment. + +```shell +# set Ascend CANN tools envs +source /usr/local/Ascend/ascend-toolkit/set_env.sh +export ASCEND_CUSTOM_PATH=${ASCEND_HOME_PATH}/../ +export ASCEND_RT_VISIBLE_DEVICES=3 +export ASCEND_TOTAL_MEMORY_GB=32 + +# mindspore envs +export MS_ALLOC_CONF=enable_vmm:true +export CPU_AFFINITY=0 + +# vLLM envs +export VLLM_MODEL_MEMORY_USE_GB=26 + +# backend envs +export VLLM_MASTER_IP=127.0.0.1 +export VLLM_RPC_PORT=12390 +export VLLM_HTTP_PORT=8080 +unset vLLM_MODEL_BACKEND + +# model envs +export MODEL_ID="/path/to/model/Qwen2-7B-Instruct" +``` + +Run the following command to start the vLLM MindSpore service backend: + +```shell +vllm-mindspore serve --model=${MODEL_ID} --port=${VLLM_HTTP_PORT} --trust_remote_code --max-num-seqs=256 --max_model_len=32768 --max-num-batched-tokens=4096 --block_size=128 --gpu-memory-utilization=0.9 --tensor-parallel-size 1 --data-parallel-size 1 --data-parallel-size-local 1 --data-parallel-start-rank 0 --data-parallel-address ${VLLM_MASTER_IP} --data-parallel-rpc-port ${VLLM_RPC_PORT} &> vllm-mindspore.log & +``` + +After the backend service is loaded, the listening port and provided APIs of the backend service are displayed. 
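+
+Because the startup command above redirects its output to `vllm-mindspore.log`, you can watch that file to confirm that the model has finished loading and to see the listening port. Once the service is up, a quick way to check it is to query the model list (assuming the standard OpenAI-compatible `/v1/models` route exposed by vLLM):
+
+```shell
+# watch the startup log written by the serve command above
+tail -f vllm-mindspore.log
+
+# after loading completes, list the served models
+curl http://${VLLM_MASTER_IP}:${VLLM_HTTP_PORT}/v1/models
+```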
+ +### Sending a Request + +You can run the following command to send an HTTP request to implement model inference: + +```shell +curl http://${VLLM_MASTER_IP}:${VLLM_HTTP_PORT}/v1/completions -H "Content-Type: application/json" -d "{\"model\": \"${MODEL_ID}\", \"prompt\": \"I love Beijing, because\", \"max_tokens\": 128, \"temperature\": 1.0, \"top_p\": 1.0, \"top_k\": 1, \"repetition_penalty\": 1.0}" +``` + +After receiving the inference request, the service backend calculates and returns the following results: + +```json +{ + "id":"cmpl-1c30caf453154b5ab4a579b7b06cea19", + "object":"text_completion", + "created":1754103773, + "model":"/path/to/model/Qwen2-7B-Instruct", + "choices":[ + { + "index":0, + "text":" it is a city with a long history and rich culture. I have been to many places of interest in Beijing, such as the Great Wall, the Forbidden City, the Summer Palace, and the Temple of Heaven. I also visited the National Museum of China, where I learned a lot about Chinese history and culture. The food in Beijing is also amazing, especially the Peking duck and the dumplings. I enjoyed trying different types of local cuisine and experiencing the unique flavors of Beijing. The people in Beijing are friendly and welcoming, and they are always willing to help tourists. I had a great time exploring the city and interacting with the locals", + "logprobs":null, + "finish_reason":"length", + "stop_reason":null, + "prompt_logprobs":null + } + ], + "usage":{ + "prompt_tokens":5, + "total_tokens":133, + "completion_tokens":128, + "prompt_tokens_details":null + } +} +``` diff --git a/tutorials/source_en/model_infer/ms_infer/ms_infer_network_develop.md b/tutorials/source_en/model_infer/ms_infer/ms_infer_network_develop.md new file mode 100644 index 0000000000000000000000000000000000000000..86ca7cb2ac0181e2b9753fd25d8a89d84241b752 --- /dev/null +++ b/tutorials/source_en/model_infer/ms_infer/ms_infer_network_develop.md @@ -0,0 +1,723 @@ +# Building an LLM Inference Network from Scratch + +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/model_infer/ms_infer/ms_infer_network_develop.md) + +## Model Development Modes + +MindSpore provides two model running modes: + +- **Static graph mode**: The model network is compiled into a complete network graph for convergence and optimization, improving the model execution performance. However, due to some syntax support issues, model development has certain limitations, affecting the usability. + +- **Dynamic graph mode**: Python statements of network scripts are executed one by one, facilitating printing and debugging (by using the PDB) at any time. This mode is easy to use, but its performance is not as good as that of the static graph mode. + +In MindSpore, you are advised to use the dynamic graph mode to develop a model and then convert dynamic graphs to static graphs as required to obtain the maximum model performance. + +## Backbone Network Used for Development in Dynamic Graph Mode + +Most mainstream LLMs use the Transformer-based backbone network, where core computing relies on the self-attention mechanism. The following figure uses the Qwen2 LLM as an example to show the backbone network architecture. 
+ +![Qwen2 network architecture](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/llm_qwen2_network_arch.png) + +The core layer of Qwen2 consists of the following parts: + +- **Embedding**: converts the index corresponding to each token into a vector to implement feature dispersion. Similar to one-hot vectorization, the embedding weights are involved in the training process, which can better adapt to the context semantics in the LLM. This process is implemented through the embedding operator. + +- **DecodeLayer**: refers to the Transformer structure, which is a key compute module of the LLM. Generally, multiple layers of computation are configured as needed. Each layer is actually a Transformer structure. + +- **RmsNorm & Linear**: linearly normalizes the output of each layer to the same dimension as the model vocabulary after computation by the transformer structure and returns the probability distribution of each token. + +You can use the MindSpore LLM to build a network for inference. The network can be assembled as required using operators provided by MindSpore. The following uses the Qwen2 model as an example to describe how to build a model. For details about the complete end-to-end example, see [qwen2.py](https://gitee.com/mindspore/docs/blob/master/docs/sample_code/infer_code/qwen2/qwen2.py). + +### Basic Common Network Layer + +The Qwen2 LLM has many configurations and parameters. To manage them more conveniently, you need to define the Config and Input classes to be used by the model. In addition, note that the Linear and RmsNorm operators are frequently used in each functional layer of the network. You can build these common layers in advance. + +#### Config & Input + +```python +import json +from dataclasses import dataclass +from typing import Optional, Type, List, Tuple, Union + +from mindspore import Tensor, dtype + +@dataclass +class Qwen2Config: + """Qwen2 Config, the key-value is almost the same with config.json in Hugging Face""" + architectures: Optional[List[str]] = None + attention_dropout: float = 0.0 + bos_token_id: int = 151643 + eos_token_id: int = 151645 + hidden_act: str = "silu" + hidden_size: int = 3584 + initializer_range: float = 0.02 + intermediate_size: int = 18944 + max_position_embeddings: int = 32768 + max_window_layers: int = 28 + model_type: str = "qwen2" + num_attention_heads: int = 28 + num_hidden_layers: int = 28 + num_key_value_heads: int = 4 + rms_norm_eps: float = 1e-06 + rope_theta: float = 1000000.0 + sliding_window: Optional[int] = 131072 + tie_word_embeddings: bool = False + torch_dtype: str = "bfloat16" + transformers_version: str = "4.41.2" + use_cache: bool = True + use_sliding_window: bool = False + vocab_size: int = 152064 + param_dtype: Optional[Type] = dtype.bfloat16 # this is mindspore datatype as hugging face use str as dtype + + @classmethod + def from_json(cls, json_path: str) -> 'Qwen2Config': + with open(json_path) as f: + data = json.load(f) + config = cls(**data) + return config + + +@dataclass +class Qwen2ModelInput: + input_ids: Tensor + positions: Tensor + batch_valid_length: Tensor + is_prefill: bool + attn_mask: Tensor + k_caches: List[Tensor] + v_caches: List[Tensor] + slot_mapping: Tensor = None + block_tables: Tensor = None + hidden_state: Optional[Tensor] = None + residual: Optional[Tensor] = None + q_seq_lens: Optional[Tensor] = None +``` + +The Qwen2Config configuration is basically the same as that of Hugging Face. 
For details, see the official Qwen2 documentation. Note that **param_dtype** is used to replace **torch_dtype** in Qwen2Config because the data types of MindSpore are different from those of PyTorch. Qwen2ModelInput defines the model input, including the word ID, KVCache, and Attention fused operator, which are required by MindSpore inference optimization features. + +#### RmsNorm + +RmsNorm is a normalization algorithm commonly used in most LLMs. MindSpore provides operators that can be directly used. You only need to create the corresponding weights. In addition, RmsNorm often involves residual computing. The RmsNorm class implements residual converged computing at the network layer. The following is a code example: + +```python +from typing import Optional, Type, Union, Tuple + +from mindspore import nn, ops, mint, Parameter, Tensor + +class RmsNorm(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.rms_norm = ops.RmsNorm(config.rms_norm_eps) + + self.weight = Parameter( + mint.ones( + config.hidden_size, + dtype=config.param_dtype + ), + requires_grad=False + ) + + def construct(self, x: Tensor, residual: Optional[Tensor] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]: + if residual is not None: + x = x + residual + residual = x + output = self.rms_norm(x, self.weight)[0] + if residual is None: + return output + return output, residual +``` + +#### Linear + +The Linear layer is actually a linear transformation. Its main computing logic is matrix multiplication (MatMul). However, bias correction may be required for addition depending on the specific application scenario (bias is required during query, key, and value conversion). The following code integrates these computations into a network structure: + +```python +from typing import Optional, Type + +from mindspore import nn, ops, mint, Parameter, Tensor + +class Qwen2Linear(nn.Cell): + def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None: + super().__init__() + + self.param_dtype = param_dtype + self.input_size = input_size + self.output_size = output_size + self.enable_bias = enable_bias + + self.matmul = ops.MatMul(transpose_b=True) + self.weight = Parameter( + mint.zeros( + (self.output_size, self.input_size), + dtype=self.param_dtype + ), + requires_grad=False + ) + + if self.enable_bias: + self.bias_add = ops.Add() + self.bias = Parameter( + mint.zeros(self.output_size, dtype=self.param_dtype) + ) + + def construct(self, input: Tensor): + origin_shape = input.shape + x = self.matmul(input.view(-1, origin_shape[-1]), self.weight) + if self.enable_bias: + x = self.bias_add(x, self.bias) + return x.view(*origin_shape[:-1], -1) +``` + +Because multi-batch computation is required, the input **shape** may be *n* times of **input_size**. To ensure correct computation, the original input **shape** is saved. After the computation is complete, **shape** is restored through the view. + +### Qwen2ForCausalLM + +The Qwen2 model is usually encapsulated for specific services. For example, Qwen2ForCausalLM is an encapsulation of Qwen2 for language processing and dialog services. + +The Qwen2ForCausalLM class is used to clearly define the main APIs of the model. 
The following shows the specific implementation: + +```python +from glob import glob +from typing import Optional, Type + +from mindspore import nn, Tensor, load_checkpoint, load_param_into_net + +class Qwen2ForCausalLM(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.model = Qwen2Model(config=config) + self.lm_head = Qwen2Linear( + input_size=config.hidden_size, + output_size=config.vocab_size, + param_dtype=config.param_dtype, + enable_bias=False + ) + + def load_weight(self, weight_path: str) -> None: + weight_dict = {} + for path in glob(weight_path + "/*.safetensors"): + weight_dict.update(load_checkpoint(path, format="safetensors")) + + load_param_into_net(self, weight_dict, strict_load=False) + + def construct(self, model_input: Qwen2ModelInput) -> Tensor: + hidden_state = self.model(model_input.input_ids, model_input.positions, + model_input.batch_valid_length, model_input.is_prefill, + model_input.k_caches, model_input.v_caches, model_input.slot_mapping, + model_input.block_tables, model_input.attn_mask, model_input.q_seq_lens) + logits = self.lm_head(hidden_state)[:, -1] + return logits +``` + +As shown in the code, Qwen2ForCausalLM has two core APIs: + +- load_weight: loads weights from the Hugging Face official website model and injects them into the model based on the network script. + +- construct: performs inference and computing, and calls submodules to complete computing layer by layer. + As shown in the construct, the core of the model is the backbone network computing and the linear computing of the last **lm_head**, which converts the features of **hidden_size** into the vocabulary probability distribution of **vocab_size**. + +### Qwen2Model + +Qwen2Model is the main network of the Qwen2 model. It consists of two parts: the embedding layer that converts the input into features and the decoder structure of *n* Transformer layers. + +#### Embedding + +The logic of the embedding layer is simple. It obtains the feature data (which is also a part of the training weight) of **hidden_size** based on the input word ID through a gather operator. The code is as follows: + +```python +from typing import Optional, Type + +from mindspore import nn, ops, mint, Parameter, Tensor + +class VocabEmbedding(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.num_embeddings = config.vocab_size + self.embedding_dim = config.hidden_size + + self.gather = ops.Gather() + + self.weight = Parameter( + mint.zeros( + (self.num_embeddings, self.embedding_dim), + dtype=config.param_dtype + ), + requires_grad=False + ) + + def construct(self, input_ids: Tensor): + return self.gather(self.weight, input_ids, 0) +``` + +#### DecoderLayer + +DecoderLayer is the core computing unit of the Transformer network. Most of the computing operations are performed at this layer. As shown in the Qwen2 network structure diagram, the network layers include RoPE, Attention, and MLP. To facilitate development, these network layers are constructed first. + +- **RoPE** + + The rotary position embedding (RoPE) operator is used to enhance the Attention mechanism's capability to perceive the distance between words by adding positional encoding information to the features of the query and key. Due to the features of RoPE, the result can be pre-computed and directly obtained by querying the table, thereby achieving efficient computation. This can be implemented using the gather and the RoPE operators. 
For details about the calculation method, see the related documents of RoPE. + + ```python + import numpy as np + from typing import Optional, Type + + from mindspore import nn, ops, mint, Parameter, Tensor + + class Qwen2RotaryEmbedding(nn.Cell): + def __init__(self, head_size: int, rotary_dim: int, max_position_embeddings: int, base: int, dtype: Optional[Type]) -> None: + super().__init__() + + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.dtype = dtype + + # format 2 is neox style + self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2) + self.gather = ops.Gather() + + self.freqs_cos, self.freqs_sin = self._compute_cos_sin_cache() + + def _compute_inv_freq(self) -> Tensor: + freqs_base = mint.arange(0, self.rotary_dim, 2).astype(np.float32) + freqs = 1.0 / (self.base ** (freqs_base / self.rotary_dim)) + return freqs + + def _compute_cos_sin_cache(self) -> Tuple[Tensor, Tensor]: + freqs = self._compute_inv_freq() + t = np.arange(0, self.max_position_embeddings, 1).astype(np.float32) + freqs = np.outer(t, freqs) + emb = np.concatenate((freqs, freqs), axis=1) + freqs_cos = np.cos(emb) + freqs_sin = np.sin(emb) + + freqs_cos = Tensor(freqs_cos, dtype=self.dtype) + freqs_sin = Tensor(freqs_sin, dtype=self.dtype) + return freqs_cos, freqs_sin + + def construct(self, positions: Tensor, query: Tensor, key: Tensor, batch_valid_length: Tensor, is_prefill: bool): + query = query.contiguous() + key = key.contiguous() + + if is_prefill: + freqs_cos = self.freqs_cos + freqs_sin = self.freqs_sin + else: + freqs_cos = self.gather(self.freqs_cos, positions.view(-1), 0) + freqs_sin = self.gather(self.freqs_sin, positions.view(-1), 0) + + return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) + ``` + +- **Attention** + + An attention layer consists of multiple Linear and RoPE operators, and attention score calculation. MindSpore provides two fusion operators, FlashAttention and PagedAttention, to enhance the inference performance of attention score calculation. + + However, because these native operators are oriented to multiple scenarios and the input is complex, they are encapsulated here to simplify the usage. 
For details about the code, see the following: + + ```python + import numpy as np + from typing import Optional, Type + + from mindspore import nn, ops, mint, Parameter, Tensor + + class FlashAttention(nn.Cell): + def __init__(self, scale: float, num_heads: int) -> None: + super().__init__() + + input_layout = "TH" + scale = scale + pre_tokens = 2147483647 + next_tokens = 2147483647 + self.flash_attention = ops.operations.nn_ops.FlashAttentionScore(head_num=num_heads, + scale_value=scale, + pre_tokens=pre_tokens, + next_tokens=next_tokens, + input_layout=input_layout) + + def construct(self, q: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor, batch_valid_length: Tensor) -> Tensor: + _, _, _, output = self.flash_attention( + q, + k, + v, + None, + None, + None, + attn_mask, + None, + batch_valid_length, + batch_valid_length + ) + return output + + + class PagedAttention(nn.Cell): + def __init__(self, head_num: int, scale: float, num_kv_heads: int) -> None: + super().__init__() + + self.head_num = head_num + self.num_kv_heads = num_kv_heads + + self.paged_attention = ops.auto_generate.PagedAttention( + head_num=head_num, + scale_value=scale, + kv_head_num=num_kv_heads + ) + + def construct(self, q: Tensor, k_cache: Tensor, v_cache: Tensor, + block_tables: Tensor, batch_valid_length: Tensor, + attn_mask: Tensor, q_seq_lens: Tensor) -> Tensor: + output = self.paged_attention(q, k_cache, v_cache, block_tables, batch_valid_length, None, None, attn_mask, q_seq_lens) + return output + ``` + + The code of the attention layer may be implemented by using the constructed network layer. For details about the code, see the following: + + ```python + import numpy as np + from typing import Optional, Type + + from mindspore import nn, ops, mint, Parameter, Tensor + + + class Qwen2Attention(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim =config.hidden_size // self.num_heads + self.q_size = self.head_dim * self.num_heads + self.kv_size = self.head_dim * self.num_kv_heads + self.scaling = float(self.head_dim ** -0.5) + self.rope_theta = int(config.rope_theta) + self.param_dtype = config.param_dtype + self.max_position = config.max_position_embeddings + + self.flash_attn = FlashAttention(self.scaling, self.num_heads) + self.paged_attn = PagedAttention(self.num_heads, self.scaling, self.num_kv_heads) + self.reshape_and_cache = ops.auto_generate.ReshapeAndCache() + + self.q_proj = Qwen2Linear( + input_size=self.hidden_size, + output_size=self.q_size, + param_dtype=self.param_dtype, + enable_bias=True + ) + self.k_proj = Qwen2Linear( + input_size=self.hidden_size, + output_size=self.kv_size, + param_dtype=self.param_dtype, + enable_bias=True + ) + self.v_proj = Qwen2Linear( + input_size=self.hidden_size, + output_size=self.kv_size, + param_dtype=self.param_dtype, + enable_bias=True + ) + self.o_proj = Qwen2Linear( + input_size=self.q_size, + output_size=self.hidden_size, + param_dtype=self.param_dtype, + enable_bias=False + ) + + self.rotary_emb = Qwen2RotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=self.max_position, + base=self.rope_theta, + dtype=self.param_dtype + ) + + def construct(self, hidden_state: Tensor, positions: Tensor, batch_valid_length: Tensor, + is_prefill: bool, layer_idx: int, k_cache: Tensor, v_cache: Tensor, + slot_mapping: Tensor, block_tables: Tensor, 
attn_mask: Tensor, + q_seq_lens: Tensor) -> Tensor: + bs, seq_len, hidden_dim = hidden_state.shape + + q = self.q_proj(hidden_state).view(-1, self.q_size) + k = self.k_proj(hidden_state).view(-1, self.kv_size) + v = self.v_proj(hidden_state).view(-1, self.kv_size) + + q, k = self.rotary_emb( + positions, + q, + k, + batch_valid_length, + is_prefill + ) + + k = k.contiguous() + v = v.contiguous() + + cache_out = self.reshape_and_cache( + k, + v, + k_cache, + v_cache, + slot_mapping + ) + q = ops.depend(q, cache_out) + + if is_prefill: + attn_output = self.flash_attn( + q, + k, + v, + attn_mask, + batch_valid_length + ) + else: + attn_output = self.paged_attn( + q, + k_cache, + v_cache, + block_tables, + batch_valid_length, + attn_mask, + q_seq_lens + ) + + output = self.o_proj(attn_output).view(bs, seq_len, -1) + return output + ``` + +- **MLP** + + An MLP layer, consisting of multiple Linear operators and an activation function (usually silu), is responsible for implementing non-linear computation of the network. The MLP layer can project problems to multiple non-linear spaces, thereby enhancing network capabilities. For details about the implementation, see the following code: + + ```python + import numpy as np + from typing import Optional, Type + + from mindspore import nn, ops, mint, Parameter, Tensor + + class Qwen2MLP(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.up_proj = Qwen2Linear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + param_dtype=config.param_dtype, + enable_bias=False + ) + self.gate_proj = Qwen2Linear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + param_dtype=config.param_dtype, + enable_bias=False + ) + self.down_proj = Qwen2Linear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + param_dtype=config.param_dtype, + enable_bias=False + ) + self.act_fn = ops.silu + + def construct(self, x: Tensor) -> Tensor: + output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return output + ``` + +DecoderLayer may be constructed as follows by referring to the preceding network layer: + +```python +from typing import Tuple +from mindspore import nn, Tensor + +class Qwen2DecoderLayer(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config) + self.mlp = Qwen2MLP(config=config) + self.input_layernorm = RmsNorm(config=config) + self.post_attention_layernorm = RmsNorm(config=config) + + def construct(self, hidden_state: Tensor, residual: Tensor, positions: Tensor, + batch_valid_length: Tensor, is_prefill: bool, layer_idx: int, + k_cache: Tensor, v_cache: Tensor, slot_mapping: Tensor, + block_tables: Tensor, attn_mask: Tensor, q_seq_lens: Tensor) -> Tuple[Tensor, Tensor]: + if residual is None: + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + else: + hidden_state, residual = self.input_layernorm(hidden_state, residual) + + hidden_state = self.self_attn(hidden_state, positions, batch_valid_length, is_prefill, + layer_idx, k_cache, v_cache, slot_mapping, block_tables, + attn_mask, q_seq_lens) + hidden_state, residual = self.post_attention_layernorm(hidden_state, residual) + hidden_state = self.mlp(hidden_state) + + return hidden_state, residual +``` + +#### Model + +After the embedding and decoder layers are constructed, you can construct the Model class by referring to the following code: + 
```python
from typing import List

from mindspore import nn, ops, mint, Parameter, Tensor

class Qwen2Model(nn.Cell):
    def __init__(self, config: Qwen2Config) -> None:
        super().__init__()

        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers

        self.embed_tokens = VocabEmbedding(config=config)
        self.layers = nn.CellList()
        for i in range(config.num_hidden_layers):
            layer = Qwen2DecoderLayer(config=config)
            self.layers.append(layer)
        self.norm = RmsNorm(config=config)

    def construct(self, input_ids: Tensor, positions: Tensor, batch_valid_length: Tensor,
                  is_prefill: bool, k_caches: List[Tensor], v_caches: List[Tensor],
                  slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor,
                  q_seq_lens: Tensor) -> Tensor:
        hidden_state = self.embed_tokens(input_ids)
        residual = None

        for i in range(self.num_hidden_layers):
            layer = self.layers[i]
            hidden_state, residual = layer(hidden_state, residual, positions, batch_valid_length,
                                           is_prefill, i, k_caches[i], v_caches[i], slot_mapping,
                                           block_tables, attn_mask, q_seq_lens)

        hidden_state, _ = self.norm(hidden_state, residual)

        return hidden_state
```

### KVCacheManager

KVCache is a standard optimization for LLM inference. To use KVCache together with the FlashAttention and PagedAttention operators provided by MindSpore, several additional parameters need to be passed in, including:

- **k_cache & v_cache**: The kv_cache objects can be regarded as cache tables that store the keys and values computed in previous iterations. In the next iteration, these values are read directly, which avoids recomputing the keys and values of the first *n* tokens and thereby improves performance.

- **block_tables & slot_mapping**: PagedAttention stores the KVCache in fixed-size blocks through a paging-like mechanism, so that the cache does not need to occupy large contiguous memory, thereby improving graphics memory utilization.

According to the preceding description, these parameters can be encapsulated in a management class.
The code can be referenced as follows:

```python
import math
from collections import deque
from typing import Tuple

from mindspore import nn, ops, mint, Parameter, Tensor, mutable, dtype

class CacheManager:
    def __init__(self, config: Qwen2Config, block_num: int, block_size: int, batch_size: int) -> None:
        self.block_num = block_num
        self.block_size = block_size
        self.batch_size = batch_size

        head_dim = config.hidden_size // config.num_attention_heads

        self.k_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
        self.v_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
        self.block_tables = [[] for _ in range(batch_size)]
        self.acc_slot_mapping = [[] for _ in range(batch_size)]
        self.free_block_ids = deque(range(block_num))

    def step(self, start_pos_idx: int, token_num_per_batch: int) -> Tuple[Tensor, Tensor]:
        for i in range(self.batch_size):
            block_table = self.block_tables[i]
            total_block_num = math.ceil((start_pos_idx + token_num_per_batch) / self.block_size)
            now_block_num = len(block_table)
            # Allocate new cache blocks when the tokens of this request exceed the allocated blocks.
            for _ in range(total_block_num - now_block_num):
                block_id = self.free_block_ids.popleft()
                block_table.append(block_id)
                start_slot_id = block_id * self.block_size
                self.acc_slot_mapping[i].extend(list(range(start_slot_id, start_slot_id + self.block_size)))

        now_block_tables = Tensor(self.block_tables, dtype=dtype.int32)
        now_slot_mapping = Tensor([self.acc_slot_mapping[i][start_pos_idx: start_pos_idx + token_num_per_batch]
                                   for i in range(self.batch_size)], dtype=dtype.int32).view(-1)

        return now_block_tables, now_slot_mapping
```

### Sampler

After the backbone network is computed, the network output is a logits tensor of shape [*batch_size*, *vocab_size*], which represents the probability distribution of the next token for each inference request in the batch. A token then needs to be selected from the vocabulary as the final result. To simplify selection and eliminate randomness, the token with the maximum probability is selected each time, that is, an argmax computation is performed. The following is a code example:

```python
from mindspore import Tensor

def sample(logits: Tensor) -> Tensor:
    next_token = logits.argmax(axis=-1, keepdims=True)
    return next_token
```

## Converting Dynamic Graphs to Static Graphs

MindSpore can convert dynamic graphs to static graphs using JIT to improve inference performance.
In terms of code implementation, you can use the following simple decorator for conversion:

```python
from typing import List

from mindspore import nn, ops, mint, Parameter, Tensor, jit


class Qwen2Model(nn.Cell):
    def __init__(self, config: Qwen2Config) -> None:
        super().__init__()

        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers

        self.embed_tokens = VocabEmbedding(config=config)
        self.layers = nn.CellList()
        for i in range(config.num_hidden_layers):
            layer = Qwen2DecoderLayer(config=config)
            self.layers.append(layer)
        self.norm = RmsNorm(config=config)

    @jit(jit_level="O0", infer_boost="on")
    def construct(self, input_ids: Tensor, positions: Tensor, batch_valid_length: Tensor,
                  is_prefill: bool, k_caches: List[Tensor], v_caches: List[Tensor],
                  slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor,
                  q_seq_lens: Tensor) -> Tensor:
        hidden_state = self.embed_tokens(input_ids)
        residual = None

        for i in range(self.num_hidden_layers):
            layer = self.layers[i]
            hidden_state, residual = layer(hidden_state, residual, positions, batch_valid_length,
                                           is_prefill, i, k_caches[i], v_caches[i], slot_mapping,
                                           block_tables, attn_mask, q_seq_lens)

        hidden_state, _ = self.norm(hidden_state, residual)

        return hidden_state
```

Add the mindspore.jit decorator to the construct method of nn.Cell to execute the computation of the cell in static graph mode. The parameters are described as follows:

- **jit_level**: specifies the compilation level. Currently, MindSpore inference supports the O0 and O1 levels (O1 adds some operator fusion optimization).

- **infer_boost**: enables inference acceleration optimization. After this option is enabled, some scheduling and stream optimizations are performed during runtime to improve inference performance.

In addition, due to the limitations of the static graph mode of MindSpore, dynamic-to-static conversion may fail in some scenarios. The following lists some common causes:

- **setattr usage**: The Python setattr syntax is not supported during MindSpore graph capture. Therefore, parameters cannot be encapsulated in a wrapper class. For example, Qwen2ModelInput in the preceding example cannot be passed directly to a Qwen2Model that has been converted to a static graph; otherwise, static graph execution fails.

- **List value**: If list parameters are passed to a graph-converted network, they must be wrapped with mutable so that MindSpore can process them correctly, for example, **k_caches** and **v_caches** in the preceding example. Otherwise, a fallback to Python is triggered, which affects inference performance; in some scenarios, the computation may even fail. A minimal sketch is given at the end of this section.

- **Graph input name**: If the PagedAttention operator of MindSpore is used, the two corresponding graph inputs must be named **batch_valid_length** and **q_seq_lens**. Otherwise, the PagedAttention operator fails to be initialized.

If you plan to use static graph inference later when developing models with MindSpore, you are advised to keep the preceding limitations in mind during dynamic graph development and debugging to avoid extra migration and debugging costs afterwards.
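The list-input limitation can be illustrated with a minimal, self-contained sketch. The toy `SumCaches` cell below is made up for illustration only (it is not part of the Qwen2 network); it simply shows a list of tensors being wrapped with `mutable` before being passed to a jit-compiled `construct`. The decorator arguments follow the example above; on hardware where `infer_boost` is unavailable, that option may need to be removed.

```python
import mindspore as ms
from mindspore import nn, ops, Tensor, mutable, jit

class SumCaches(nn.Cell):
    """Toy cell whose construct consumes a list of tensors, mimicking k_caches/v_caches."""

    @jit(jit_level="O0", infer_boost="on")
    def construct(self, x: Tensor, caches):
        # The list argument is iterated inside the compiled graph.
        for cache in caches:
            x = x + cache
        return x

net = SumCaches()
x = Tensor([[1.0, 2.0]], ms.float32)

# Wrap the Python list with mutable() so that graph compilation treats it as a
# variable container instead of a constant, avoiding a fallback to Python.
caches = mutable([ops.ones((1, 2), ms.float32) for _ in range(4)])

print(net(x, caches))  # expected value: [[5. 6.]]
```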
diff --git a/tutorials/source_en/model_infer/ms_infer/ms_infer_parallel_infer.md b/tutorials/source_en/model_infer/ms_infer/ms_infer_parallel_infer.md
new file mode 100644
index 0000000000000000000000000000000000000000..45b6144a7d105605438004ed23de46b16458aef9
--- /dev/null
+++ b/tutorials/source_en/model_infer/ms_infer/ms_infer_parallel_infer.md
@@ -0,0 +1,941 @@
# Building a Parallel LLM Network

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/model_infer/ms_infer/ms_infer_parallel_infer.md)

As model sizes continue to expand, the computing resources required by LLMs, particularly graphics memory, grow exponentially. For example, Qwen2-72B requires approximately 144 GB of graphics memory at half precision (FP16).

In addition, the increasing sequence length of LLMs places immense pressure on graphics memory. Graphics memory not only affects model loading, but also limits the batch size. A small batch size may reduce inference efficiency, which in turn lowers the throughput of the entire system.

The pressure on graphics memory makes it challenging for a single device to complete inference tasks within a reasonable time frame, and parallel computing has become a key strategy to address this challenge. This section uses the network structure of a common LLM as an example to analyze the model parallelism solution.

## Model Parallelism Requirement Analysis

Before sharding and parallelizing a model, you need to analyze the model structure to determine which layers can be parallelized and how to split them to obtain the best speedup. Ideally, the parallelized parts can be computed independently, minimizing the impact on the rest of the network. The following uses the Qwen2 model structure as an example to analyze the parallelism of the main network layers:

- **Embedding**: The embedding layer is essentially a gather operation and can be parallelized well regardless of the sharding dimension (**hidden_dim** or **num_embeddings**). Because sharding along **num_embeddings** allows the partial results to be combined with **all_reduce** (reducing the overhead of data rearrangement), sharding is performed along the **num_embeddings** dimension.

- **Attention**: The Qwen2 model uses grouped-query attention (GQA), that is, multiple independent attention computations. Therefore, the query, key, and value can each be parallelized by column. Note that the number of attention heads must be divisible by the number of shards.

- **MLP**: The MLP layer is essentially matrix multiplications of two Linear layers, which can be sharded by block.

- **RmsNorm&Add**: RmsNorm normalizes a whole row of data, which requires global information, so it cannot be effectively parallelized; an all_reduce is needed to aggregate the data before RmsNorm is computed. In addition, Add and RmsNorm are usually used together and are not sharded.

- **LMHead**: The LMHead layer is essentially a Linear layer whose input **shape** is usually (*batch_size*, *hidden_size*) multiplied by (*hidden_size*, *vocab_size*). It can be sharded along **vocab_size**, and the partial results are combined with all_gather for acceleration.

The following figure shows the execution of one Qwen2 layer with a parallelism degree of 2.
+ +![matmul1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/llm_qwen2_parallel_split.png) + +As shown in the figure, RmsNorm cannot be sharded. Therefore, an AllReduce operator needs to be added to the network before each RmsNorm computing to synchronize the computing results of each subprocess. The result after RmsNorm is usually **hidden_states**. Therefore, the result can be sharded by column-wise Linear and allocated to each subprocess for computing and then normalized by RowLinear. + +## Model Module Parallelism Solution + +The Linear layer is the main network layer for sharding, and its core is MatMul (matrix computation). Therefore, matrix sharding and computation is the most important part of model parallelism. + +### Basic MatMul Module + +![matmul1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/gmm.png) + +![matmul2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/matmul.png) + +In LLM computations, matrix multiplication (MatMul) accounts for a significant portion of both weight and computation workload. MatMul exhibits both column-wise parallelism and row-wise parallelism. + +![Column-wise Parallelism](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/column.png) + +![Row-wise Parallelism](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/row.png) + +Starting with the original implementation of `nn.Dense` in MindSpore, we can build implementations for both column-wise and row-wise MatMul. + +1. Creation and management of communication domains and management of LLM configurations + + Build the `CommunicationHelper` class to manage the model parallel domain. + + ```python + from mindspore.communication import create_group, get_group_size, get_rank + ``` + + ```python + class CommunicationHelper: + def __init__(self, group_name, size): + self.group_name = group_name + self.size = size + self.rank_list = [i for i in range(size)] + + def create_tensor_model_parallel_group(self): + create_group(group=self.group_name, rank_ids=self.rank_list) + + def get_tensor_model_parallel_group_size(self): + return get_group_size(group=self.group_name) + + def get_tensor_model_parallel_group_rank(self): + return get_rank(group=self.group_name) + + def get_tensor_model_parallel_group(self): + return self.group_name + ``` + + Build `ConfigHelper` to manage and configure LLM parameters. + + ```python + class ConfigHelper: + def __init__(self, + vocab_size, + hidden_size, + ffn_hidden_size, + num_layers, + batch_size, + seq_length, dtype, + num_heads, + has_bias=False): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_layers = num_layers + self.batch_size = batch_size + self.seq_length = seq_length + self.dtype = dtype + self.num_heads = num_heads + self.has_bias = has_bias + ``` + +2. Column-wise MatMul + + The `ColumnParallelLinear` class computes the weight shape after sharding and initializes the weights based on the number of devices for model parallelism. Column-wise parallelism divides `out_channels`. In the model forward propagation process, the MatMul is called to compute the result after parallelism. 
You can perform `AllGather` on the parallelized result to obtain the complete output. + + The MindSpore training and inference integrated framework supports enabling **infer_boost**. This parameter activates the high-performance self-developed operator library within the MindSpore framework. To enable this mode, you need to: + + 1. Set variables. + + ```python + from mindspore import set_context + set_context(jit_config={"jit_level": 'O0', "infer_boost": 'on'}) + ``` + + 2. Set system environment variables. + + ```bash + export ASCEND_HOME_PATH={$ascend_custom_path} + ``` + + For example, if there are 2 devices for model parallelism, set environment variables, initialize the communication group, and configure the model parameter **config** as follows: + + ```python + from mindspore import nn, Parameter, ops, Tensor + from mindspore.common import dtype as mstype + from mindspore.communication import init + from mindspore.common.initializer import initializer + import numpy as np + + from mindspore import set_context + set_context(jit_config={"jit_level": 'O0', "infer_boost": 'on'}) + + TP_GROUP_NAME='tp' + TP_SIZE = 2 + COMMUN_HELPER = CommunicationHelper(group_name=TP_GROUP_NAME, size=TP_SIZE) + + init() + COMMUN_HELPER.create_tensor_model_parallel_group() + + config = ConfigHelper(batch_size=64, + vocab_size=32000, + num_layers=4, + seq_length=2048, + hidden_size=1024, + ffn_hidden_size=4096, + dtype=mstype.float16, + num_heads=8, + has_bias=False) + ``` + + The column-wise MatMul module is implemented as follows: + + ```python + class ColumnParallelLinear(nn.Cell): + def __init__(self, + in_channels, + out_channels, + weight_init=None, + bias_init=None, + has_bias=True, + dtype=mstype.float32): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.has_bias = has_bias + self.tensor_parallel_group_size = COMMUN_HELPER.get_tensor_model_parallel_group_size() + self.out_channels_per_partition = out_channels // self.tensor_parallel_group_size + self.dtype = dtype + weight_shape = (self.out_channels_per_partition, self.in_channels) + self.weight = Parameter(initializer(weight_init, weight_shape, self.dtype), name="weight") + if self.has_bias: + self.bias = Parameter(initializer(bias_init, (self.out_channels_per_partition), self.dtype), name="bias") + self.bias_add = ops.Add() + self.matmul = ops.BatchMatMul(transpose_b=True) + self.cast = ops.Cast() + + def construct(self, x): + origin_dtype = x.dtype + x = self.cast(x, self.dtype) + out = self.matmul(x, self.weight) + if self.has_bias: + out = self.bias_add( + out, self.cast(self.bias, self.dtype) + ) + out = self.cast(out, origin_dtype) + return out + ``` + + The output of column-wise MatMul is parallelized. To obtain a complete output, use `GatherLastDim`. 
+ + ```python + class GatherLastDim(nn.Cell): + def __init__(self): + super().__init__() + self.all_gather = ops.AllGather(group=COMMUN_HELPER.get_tensor_model_parallel_group()) + self.world_size = COMMUN_HELPER.get_tensor_model_parallel_group_size() + self.split = ops.Split(axis=0, output_num=self.world_size) + + def construct(self, input_): + output = self.all_gather(input_) + tensor_list = self.split(output) + output = ops.cat(tensor_list, axis=-1) + return output + ``` + + Inference of column-wise MatMul: + + ```python + column_parallel_linear = ColumnParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=False) + input_x = Tensor(np.random.randn(config.batch_size, config.seq_length, config.hidden_size).astype(np.float32)) + out_parallel = column_parallel_linear(input_x) + print(out_parallel.shape) + + gather_last_dim = GatherLastDim() + out = gather_last_dim(out_parallel) + print(out.shape) + ``` + +3. Row-wise MatMul + + Similar to column-wise MatMul, `RowParallelLinear` shards weights based on the size of the model parallelism domains. During initialization, sharding is performed by row, that is, sharding by `in_channels`. In the model forward propagation process, after the MatMul of the inputs and weights, `AllReduce` needs to be performed on the results of all `devices`. + + The row-wise MatMul module is implemented as follows: + + ```python + class RowParallelLinear(nn.Cell): + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init=None, + has_bias=True, + dtype=mstype.float32): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.has_bias = has_bias + self.tensor_parallel_group_size = COMMUN_HELPER.get_tensor_model_parallel_group_size() + self.in_channels_per_partition = in_channels // self.tensor_parallel_group_size + self.dtype = dtype + weight_shape = (self.out_channels, self.in_channels_per_partition) + self.weight = Parameter(initializer(weight_init, weight_shape, self.dtype), name="weight") + if self.has_bias: + self.bias = Parameter(initializer(bias_init, (self.in_channels_per_partition), self.dtype), name="bias") + self.bias_add = ops.Add() + self.bmm = ops.BatchMatMul(transpose_b=True) + self.all_reduce = ops.AllReduce(group=COMMUN_HELPER.get_tensor_model_parallel_group()) + self.cast = ops.Cast() + + def construct(self, x): + origin_dtype = x.dtype + x = self.cast(x, self.dtype) + output_parallel = self.bmm(x, self.weight) + if self.has_bias: + output_parallel = self.bias_add(output_parallel, self.cast(self.bias, self.dtype)) + out = self.all_reduce(output_parallel) + out = self.cast(out, origin_dtype) + return out + ``` + + Inference of row-wise MatMul: + + ```python + row_parallel_linear = RowParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=False) + out = row_parallel_linear(out_parallel) + print(out.shape) + ``` + +4. Embedding + + In addition to MatMul, the embedding layer can also be parallelized. The embedding weights can be sharded across multiple devices, with each device responsible for mapping a different range of token IDs. + + ![embedding2](./images/embedding2.png) + + Based on nn.Embedding, build an embedding layer for model parallelism. 
+ + ```python + class VocabParallelEmbedding(nn.Cell): + def __init__(self, + num_embeddings, + embedding_dim, + init_method="normal", + init_type=mstype.float32): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.tensor_model_parallel_size = COMMUN_HELPER.get_tensor_model_parallel_group_size() + per_partition_vocab_size = self.num_embeddings // self.tensor_model_parallel_size + self.vocab_start_index = COMMUN_HELPER.get_tensor_model_parallel_group_rank() * per_partition_vocab_size + self.vocab_end_index = self.vocab_start_index + per_partition_vocab_size + self.num_embeddings_per_partition = ( + self.vocab_end_index - self.vocab_start_index + ) + self.embedding_weight = Parameter( + initializer( + init=init_method, + shape=(self.num_embeddings_per_partition, self.embedding_dim), + dtype=init_type, + ), + name="embedding_weight", + ) + self.all_reduce = ops.AllReduce(group=COMMUN_HELPER.get_tensor_model_parallel_group()) + self.max_index_per_partition = Tensor(self.num_embeddings_per_partition - 1, dtype=mstype.int32) + self.expand_dims = ops.ExpandDims() + self.gather = ops.Gather() + self.sub = ops.Sub() + self.relu = ops.ReLU() + self.minimum = ops.Minimum() + self.eq = ops.Equal() + self.mul = ops.Mul() + + def construct(self, x): + displaced_x = self.sub(x, self.vocab_start_index) + down_truncated_x = self.relu(displaced_x) + truncated_x = self.minimum(down_truncated_x, self.max_index_per_partition) + input_mask = self.eq(displaced_x, truncated_x) + input_mask = self.expand_dims(input_mask, -1) + output_parallel = self.gather(self.embedding_weight, truncated_x, 0) + output_parallel = self.mul(output_parallel, input_mask) + output = self.all_reduce(output_parallel) + return output + ``` + + Inference of parallel embedding: + + ```python + input_ids = np.random.randint(0, config.vocab_size, size=(config.batch_size, config.seq_length), dtype=np.int32) + input_ids = Tensor(input_ids) + + vocab_parallel_embedding = VocabParallelEmbedding(num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size) + embedding_output = vocab_parallel_embedding(input_ids) + print(embedding_output.shape) + ``` + +### TransformerModel Parallel Adaptation + +It can be seen that the tensor is processed sequentially. First, it passes through the `ColumnParallelLinear` column-wise MatMul to obtain the parallelized results. Then, it is input to the `RowParallelLinear` row-wise MatMul, resulting in the complete output of the two MatMul operations. + +![Column+Row](../../../source_zh_cn/model_infer/ms_infer/images/column+row.png) + +Based on the preceding analysis, TransformerModel can be modified to support parallelism. + +1. Attention + + Take the multi-head attention (MHA) module as an example. The attention module in the Transformer is multi-headed, and attention heads are independent of each other. Therefore, the activation value can be sharded by `hidden_size` while ensuring that a single attention head is complete. For example, assume that the number of MHA headers (`num_heads`) is 16, the number of dimensions (`head_dim`) of each header is 256, then `hidden_size` is 4096, and the number of linear in/out dimensions of Q/K/V is 4096. When `tensor_model_parallel` is set to `4` for the model parallelism, these linear results are allocated to four devices. The shape of each device is (4096,1024), indicating that each device computes 4 heads. 
+ + ![MHA](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/model_infer/ms_infer/images/MHA.png) + + The following is an example of the Attention module code: + + ```python + class ParallelAttention(nn.Cell): + def __init__(self, config): + super().__init__() + self.tensor_model_parallel_size = COMMUN_HELPER.get_tensor_model_parallel_group_size() + self.num_heads_per_partition = config.num_heads // self.tensor_model_parallel_size + self.head_dim = config.hidden_size // config.num_heads + self.norm_factor = math.sqrt(self.head_dim) + self.q = ColumnParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + has_bias=config.has_bias) + self.k = ColumnParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=config.has_bias) + self.v = ColumnParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=config.has_bias) + self.flash_attention = ops.operations.nn_ops.FlashAttentionScore(head_num=self.num_heads_per_partition, + scale_value=1.0/self.norm_factor, + next_tokens=0) + self.out = RowParallelLinear(in_channels=config.hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=config.has_bias) + + def construct(self, x, mask): + query = self.q(x) + key = self.k(x) + value = self.v(x) + _, _, _, context_layer = self.flash_attention(query, key, value, attn_mask=mask) + output = self.out(context_layer) + return output + ``` + +2. MLP + + The MLP module is two fully-connected layers, which can also be processed by parallel MatMul. The code is as follows: + + ```python + class ParallelMLP(nn.Cell): + def __init__(self, config): + super().__init__() + self.w1 = ColumnParallelLinear(in_channels=config.hidden_size, + out_channels=config.ffn_hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=config.has_bias) + self.w2 = RowParallelLinear(in_channels=config.ffn_hidden_size, + out_channels=config.hidden_size, + weight_init='normal', + dtype=config.dtype, + has_bias=config.has_bias) + self.act_func = nn.SiLU() + self.mul = ops.Mul() + + def construct(self, x): + x = self.w1(x) + x = self.act_func(x) + output = self.w2(x) + return output + ``` + +3. TransformerLayer + + TransformerLayer consists of Attention and MLP. Since there are no single operators that can be parallelized, you only need to pass the parallel parameters to Attention and MLP. + + ```python + class ParallelTransformerLayer(nn.Cell): + def __init__(self, config): + super().__init__() + self.attention = ParallelAttention(config=config) + self.feed_forward = ParallelMLP(config=config) + self.attention_norm = RMSNorm(dim=config.hidden_size, dtype=config.dtype) + self.ffn_norm = RMSNorm(dim=config.hidden_size, dtype=config.dtype) + self.add = ops.Add() + + def construct(self, x, mask): + norm_output = self.attention_norm(x) + attention_output = self.attention(norm_output, mask) + norm_input = self.add(x, attention_output) + norm_output = self.ffn_norm(norm_input) + mlp_output = self.feed_forward(norm_output) + output = self.add(norm_input, mlp_output) + return output + ``` + +4. 
TransformerModel

    ```python
    class ParallelTransformer(nn.Cell):
        def __init__(self, config):
            super().__init__()
            self.embedding = VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                    embedding_dim=config.hidden_size,
                                                    init_method='normal',
                                                    init_type=config.dtype)
            self.layers = nn.CellList()
            self.num_layers = config.num_layers
            for _ in range(config.num_layers):
                layer = ParallelTransformerLayer(config=config)
                self.layers.append(layer)
            self.norm_out = RMSNorm(dim=config.hidden_size, dtype=config.dtype)

        def construct(self, x, mask):
            hidden_state = self.embedding(x)
            for i in range(self.num_layers):
                hidden_state = self.layers[i](hidden_state, mask)
            hidden_state = self.norm_out(hidden_state)
            return hidden_state
    ```

For details about the end-to-end LLM code project, see the [model_dev.py](https://gitee.com/mindspore/docs/blob/master/docs/sample_code/infer_code/model_dev.py) script. Run the following command to verify the code:

```shell
msrun --worker_num 2 --local_worker_num 2 --master_port 8124 --log_dir msrun_log --join True --cluster_time_out 300 model_dev.py
```

## Practice: Qwen2 Model Parallel Reconstruction

This section describes how to adapt the Qwen2 LLM developed in [Building an LLM Inference Network from Scratch](./ms_infer_network_develop.md) to parallel execution. Based on the preceding analysis, the parallel adaptation consists of two main steps:

1. **Model network adaptation**: Based on the preceding parallelism solution, parallelize the network layers in the model and distribute the computation workloads across multiple devices.

2. **Model weight adaptation**: Adjust weight loading accordingly, because the shape of the weights in the Linear layers changes after parallel sharding.

To simplify the scenario, this section shards only the Linear layers of the Qwen2 model with a parallelism degree of 2; the embedding layer is not sharded.

### Establishing a Communication Group

Before reconstructing the model, you need to use the communication module of MindSpore to establish a communication group for the subsequent communication operations. This can be implemented directly with the CommunicationHelper class described above:

```python
from mindspore.communication import create_group, get_group_size, get_rank, init

class CommunicationHelper:
    def __init__(self, group_name: str, size: int) -> None:
        self.group_name = group_name
        self.size = size
        self.rank_list = [i for i in range(size)]

    def create_tensor_model_parallel_group(self):
        create_group(group=self.group_name, rank_ids=self.rank_list)

    def get_tensor_model_parallel_group_size(self):
        return get_group_size(group=self.group_name)

    def get_tensor_model_parallel_group_rank(self):
        return get_rank(group=self.group_name)

    def get_tensor_model_parallel_group(self):
        return self.group_name

COMMON_HELPER = None

def init_communication():
    TP_GROUP_NAME = "tp"
    TP_SIZE = 2

    global COMMON_HELPER
    COMMON_HELPER = CommunicationHelper(group_name=TP_GROUP_NAME, size=TP_SIZE)
    init()
    COMMON_HELPER.create_tensor_model_parallel_group()
```

### Model Sharding and Parallelism

This solution mainly shards and parallelizes the Linear layers, so the modifications focus on the Linear layer.
In the implementation, Qwen2Linear needs to be changed to Qwen2ColParallelLinear and Qwen2RowParallelLinear, which correspond to the Linear layer of column sharding and row sharding, respectively. For details, see the following code: + +```diff +from typing import Optional, Type, Tuple + +from mindspore import nn, ops, mint, Parameter, Tensor + +class Qwen2ColParallelLinear(nn.Cell): + def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None: + super().__init__() + ++ self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size() + self.param_dtype = param_dtype + self.input_size = input_size +- self.output_size = output_size ++ self.output_size = output_size // self.tp_size + self.enable_bias = bias + + self.matmul = ops.MatMul(transpose_b=True) + self.weight = Parameter( + mint.zeros( + (self.output_size, self.input_size), + dtype=self.param_dtype + ), requires_grad=False + ) + + if self.enable_bias: + self.bias_add = ops.Add() + self.bias = Parameter( + mint.zeros(self.output_size, dtype=self.param_dtype) + ) + + def construct(self, input: Tensor) -> Tuple[Tensor, bool]: + origin_shape = input.shape + x = self.matmul(input.view(-1, origin_shape[-1]), self.weight) + if self.enable_bias: + x = self.bias_add(x, self.bias) + return x.view(*origin_shape[:-1], -1) + + +class Qwen2RowParallelLinear(nn.Cell): + def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None: + super().__init__() + ++ self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size() + self.param_dtype = param_dtype +- self.input_size = input_size ++ self.input_size = input_size // self.tp_size + self.output_size = output_size + self.enable_bias = bias + + self.matmul = ops.MatMul(transpose_b=True) + self.weight = Parameter( + mint.zeros( + (self.output_size, self.input_size), + dtype=self.param_dtype + ), requires_grad=False + ) + + if self.enable_bias: + self.bias_add = ops.Add() + self.bias = Parameter( + mint.zeros(self.output_size, dtype=self.param_dtype) + ) ++ self.all_reduce = ops.AllReduce(group=COMMON_HELPER.get_tensor_model_parallel_group()) + + def construct(self, input: Tensor) -> Tuple[Tensor, bool]: + origin_shape = input.shape + x = self.matmul(input.view(-1, origin_shape[-1]), self.weight) + if self.enable_bias: + x = self.bias_add(x, self.bias) ++ x = self.all_reduce(x) + return x.view(*origin_shape[:-1], -1) +``` + +As shown in the preceding code, the Linear reconstruction is simple. Qwen2ColParallelLinear only needs to shard the output dimension based on the parallelism degree, and Qwen2RowParallelLinear only needs to shard the input dimension based on the parallelism degree. Because all_reduce computation is required after row sharding, an all_reduce operation is added to Qwen2RowParallelLinear. + +In addition, the original Qwen2Linear layer needs to be changed to a new Linear layer based on the algorithm. Pay attention to the following three parts: + +- **Attention**: Four Linear layers are involved, including query, key, value, and output. The query, key, and value layers need to be replaced by Qwen2ColParallelLinear, and the output layer needs to be replaced by Qwen2RowParallelLinear. + +- **MLP**: Three Linear layers are involved, including gate, up, and down. The gate and up layers need to be replaced by Qwen2ColParallelLinear, and the down layer needs to be replaced by Qwen2RowParallelLinear. + +- **LMHead**: A Linear layer is involved. 
Since there is no row-wise Linear layer corresponding to it, the all_gather operation is required to obtain the results of multiple devices. + +You can replace the class objects to complete the following modifications and adaptations. The following lists the modified network layer implementation: + +```diff +import numpy as np +from typing import Optional, Type + +from mindspore import nn, ops, mint, Parameter, Tensor + + +class Qwen2Attention(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + ++ self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim =config.hidden_size // self.num_heads + self.q_size = self.head_dim * self.num_heads + self.kv_size = self.head_dim * self.num_kv_heads + self.scaling = float(self.head_dim ** -0.5) + self.rope_theta = int(config.rope_theta) + self.param_dtype = config.param_dtype + self.max_position = config.max_position_embeddings + +- self.flash_attn = FlashAttention(self.scaling, self.num_heads) +- self.paged_attn = PagedAttention(self.num_heads, self.scaling, self.num_kv_heads) ++ self.flash_attn = FlashAttention(self.scaling, self.num_heads // self.tp_size) ++ self.paged_attn = PagedAttention(self.num_heads // self.tp_size, self.scaling, self.num_kv_heads // self.tp_size) + self.reshape_and_cache = ops.auto_generate.ReshapeAndCache() + + self.q_proj = Qwen2ColParallelLinear( + input_size=self.hidden_size, + output_size=self.q_size, + param_dtype=self.param_dtype + bias=True + ) + self.k_proj = Qwen2ColParallelLinear( + input_size=self.hidden_size, + output_size=self.kv_size, + param_dtype=self.param_dtype, + bias=True + ) + self.v_proj = Qwen2ColParallelLinear( + input_size=self.hidden_size, + output_size=self.kv_size, + param_dtype=self.param_dtype, + bias=True + ) + self.o_proj = Qwen2RowParallelLinear( + input_size=self.q_size, + output_size=self.hidden_size, + param_dtype=self.param_dtype, + bias=False + ) + + self.rotary_emb = Qwen2RotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=self.max_position, + base=self.rope_theta, + dtype=self.param_dtype + ) + + def construct(self, hidden_state: Tensor, positions: Tensor, batch_valid_length: Tensor, + is_prefill, bool, layer_idx: int, k_cache: Tensor, v_cache: Tensor, + slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor, + q_seq_lens: Tensor) -> Tensor: + bs, seq_len, hidden_dim = hidden_state.shape + +- q = self.q_proj(hidden_state).view(-1, self.q_size // self.tp_size) +- k = self.k_proj(hidden_state).view(-1, self.kv_size // self.tp_size) +- v = self.v_proj(hidden_state).view(-1, self.kv_size // self.tp_size) ++ q = self.q_proj(hidden_state).view(-1, self.q_size // self.tp_size) ++ k = self.k_proj(hidden_state).view(-1, self.kv_size // self.tp_size) ++ v = self.v_proj(hidden_state).view(-1, self.kv_size // self.tp_size) + + k = k.contiguous() + v = v.contiguous() + + cache_out = self.reshape_and_cache( + k, + v, + k_cache, + v_cache, + slot_mapping + ) + q = ops.depend(q, cache_out) + + if is_prefill: + attn_output = self.flash_attn( + q, + k, + v, + attn_mask, + batch_valid_length + ) + else: + attn_output = self.paged_attn( + q, + k_cache, + v_cache, + block_tables, + batch_valid_length, + attn_mask, + q_seq_lens + ) + + output = self.o_proj(attn_output).view(bs, seq_len, -1) + return output + +class Qwen2MLP(nn.Cell): + def __init__(self, config: 
Qwen2Config) -> None: + super().__init__() + + self.up_proj = Qwen2ColParallelLinear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + param_dtype=config.param_dtype, + bias=False + ) + self.gate_proj = Qwen2ColParallelLinear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + param_dtype=config.param_dtype, + bias=False + ) + self.down_proj = Qwen2RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + param_dtype=config.param_dtype, + bias=False + ) + self.act_fn = ops.silu + + def construct(self, x: Tensor) -> Tensor: + output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return output + ++class GatherLastDim(nn.Cell): ++ def __init__(self): ++ self.all_gather = ops.AllGather(group=COMMON_HELPER.get_tensor_model_parallel_group()) ++ self.world_size = COMMON_HELPER.get_tensor_model_parallel_group_size() ++ self.split = ops.Split(axis=0, output_num=self.world_size) ++ ++ def construct(self, input: Tensor) -> Tensor: ++ output = self.all_gather(input) ++ tensor_list = self.split(output) ++ output = ops.cat(tensor_list, axis=-1) ++ return output + +class Qwen2ForCausalLM(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.model = Qwen2Model(config=config) + self.lm_head = Qwen2ColParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + param_dtype=config.param_dtype, + bias=False + ) ++ self.all_gather = GatherLastDim() + + def load_weight(self, weight_path: str) -> None: + weight_dict = {} + for path in glob(weight_path + "/*.safetensors"): + weight_dict.update(ms.load_checkpoint(path, format="safetensors")) + + ms.load_param_into_net(self, weight_dict, strict_load=False) + + def construct(self, model_input: Qwen2ModelInput) -> Tensor: + hidden_state = self.model(model_input.input_ids, model_input.positions, + model_input.batch_valid_length, model_input.is_prefill, + model_input.k_caches, model_input.v_caches, model_input.slot_mapping, + model_input.block_tables, model_input.attn_mask, model_input.q_seq_len) + logits = self.lm_head(hidden_state)[:, -1] ++ logits = self.all_gather(logits) + return logits +``` + +The code implementation changes slightly. Note that the query, key, and value in attention are sharded based on the heads of attention. Therefore, the input and output dimensions of FlashAttention and PagedAttention need to be divided by the degree of parallelism to narrow down the calculation scope; in addition, ensure that the degree of parallelism can be exactly divided by the number of heads of the query, key, and value. + +### Model Weight Sharding + +The original Qwen2ForCausalLM uses the load_param_into_net function provided by MindSpore to inject weights into the model. The logic is to load the original weights. After the model is sharded, the model to be loaded also needs to be adapted, and the size needs to be changed. Processes on non-zero cards need to read data based on the offset. Therefore, the load_weight function needs to be modified to implement weight loading in parallel mode. + +You are advised to register the loading function in the weight parameter. 
For details, see the following code: + +```diff +from typing import Optional, Type, Tuple + +from mindspore import nn, ops, mint, Parameter, Tensor + +class Qwen2ColParallelLinear(nn.Cell): + def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None: + super().__init__() + + self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size() + self.param_dtype = param_dtype + self.input_size = input_size + self.output_size = output_size // self.tp_size + self.enable_bias = bias + + self.matmul = ops.MatMul(transpose_b=True) + self.weight = Parameter( + mint.zeros( + (self.output_size, self.input_size), + dtype=self.param_dtype + ), requires_grad=False + ) ++ setattr(self.weight, "weight_load", self.weight_load) + + if self.enable_bias: + self.bias_add = ops.Add() + self.bias = Parameter( + mint.zeros(self.output_size, dtype=self.param_dtype) + ) ++ setattr(self.bias, "weight_load", self.weight_load) + + def construct(self, input: Tensor) -> Tuple[Tensor, bool]: + origin_shape = input.shape + x = self.matmul(input.view(-1, origin_shape[-1]), self.weight) + if self.enable_bias: + x = self.bias_add(x, self.bias) + return x.view(*origin_shape[:-1], -1) + ++ def weight_load(self, param: Tensor, weight: Tensor) -> None: ++ tp_rank = COMMON_HELPER.get_tensor_model_parallel_group_rank() ++ copy_dim = 0 ++ shard_size = param.shape[copy_dim] ++ start_idx = tp_rank * shard_size ++ weight = weight.narrow(copy_dim, start_idx, shard_size).contiguous() ++ ++ param.set_data(weight) ++ return None + + + +class Qwen2RowParallelLinear(nn.Cell): + def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None: + super().__init__() + + self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size() + self.param_dtype = param_dtype + self.input_size = input_size // self.tp_size + self.output_size = output_size + self.enable_bias = bias + + self.matmul = ops.MatMul(transpose_b=True) + self.weight = Parameter( + mint.zeros( + (self.output_size, self.input_size), + dtype=self.param_dtype + ), requires_grad=False + ) ++ setattr(self.weight, "weight_load", self.weight_load) + + if self.enable_bias: + self.bias_add = ops.Add() + self.bias = Parameter( + mint.zeros(self.output_size, dtype=self.param_dtype) + ) ++ setattr(self.bias, "weight_load", self.weight_load) + self.all_reduce = ops.AllReduce(group=COMMON_HELPER.get_tensor_model_parallel_group()) + + def construct(self, input: Tensor) -> Tuple[Tensor, bool]: + origin_shape = input.shape + x = self.matmul(input.view(-1, origin_shape[-1]), self.weight) + if self.enable_bias: + x = self.bias_add(x, self.bias) + x = self.all_reduce(x) + return x.view(*origin_shape[:-1], -1) + ++ def weight_load(self, param: Tensor, weight: Tensor) -> None: ++ tp_rank = COMMON_HELPER.get_tensor_model_parallel_group_rank() ++ copy_dim = 1 ++ shard_size = param.shape[copy_dim] ++ start_idx = tp_rank * shard_size ++ weight = weight.narrow(copy_dim, start_idx, shard_size).contiguous() ++ ++ param.set_data(weight) ++ return None + +class Qwen2ForCausalLM(nn.Cell): + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + + self.model = Qwen2Model(config=config) + self.lm_head = Qwen2ColParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + param_dtype=config.param_dtype, + bias=False + ) + self.all_gather = GatherLastDim() + + def load_weight(self, weight_path: str) -> None: + weight_dict = {} + for path in glob(weight_path + "/*.safetensors"): + 
weight_dict.update(ms.load_checkpoint(path, format="safetensors")) + +- ms.load_param_into_net(self, weight_dict, strict_load=False) ++ param_dict = self.parameters_dict() ++ ++ for (name, weight) in weight_dict.items(): ++ if name in param_dict: ++ param = param_dict[name] ++ if hasattr(param, "weight_load"): ++ weight_load = getattr(param, "weight_load") ++ weight_load(param, weight) ++ else: ++ param.set_data(weight) +``` + +The weight_load method is added to the network layer that requires user-defined weight loading. The user-defined weight loading method is set for the weight object by using the setattr method. During model weight loading, the corresponding parameter object is found by reading the weight mapping table, so as to update the weights. For the column-wise or row-wise Linear layer, the narrow method of Tensor is used to obtain the data with the corresponding offset. The only difference is that the sharding dimensions are different. + +### Parallel Execution + +After the model adaptation and weight adaptation are complete, you can run the following command to start multi-device execution: + +```shell +msrun --worker_num 2 --local_worker_num 2 --master_port 8124 --log_dir msrun_log --join True --cluster_time_out 300 infer_parallel.py +``` + +**infer_parallel.py** is the inference script. diff --git a/tutorials/source_en/model_infer/ms_infer/ms_infer_quantization.md b/tutorials/source_en/model_infer/ms_infer/ms_infer_quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..5d20e83bcf527e264f1e56ea896c192f40e74057 --- /dev/null +++ b/tutorials/source_en/model_infer/ms_infer/ms_infer_quantization.md @@ -0,0 +1,204 @@ +# Model Quantization + +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/model_infer/ms_infer/ms_infer_quantization.md) + +## Overview + +MindSpore is an all-scenario AI framework. When a model is deployed on the device or other lightweight devices, it may be subject to memory, power consumption, and latency. Therefore, the model needs to be compressed before deployment. + +[MindSpore Golden Stick](https://www.mindspore.cn/golden_stick/docs/en/master/index.html) provides the model compression capability of MindSpore. MindSpore Golden Stick is a set of model compression algorithms jointly designed and developed by Huawei Noah's Ark team and Huawei MindSpore team. It provides a series of model compression algorithms for MindSpore, supporting quantization modes such as A16W8, A16W4, A8W8, and KVCache. For details, see [MindSpore Golden Stick](https://www.mindspore.cn/golden_stick/docs/en/master/index.html). + +## Basic Model Quantization Process + +To help you understand the basic model quantization process of MindSpore Golden Stick, this section uses the quantization algorithm as an example to describe the basic usage. + +### Procedure + +The MindSpore Golden Stick quantization algorithm can be divided into two phases: quantization phase and deployment phase. The quantization phase is completed before deployment. The main tasks are as follows: collecting weight distribution, computing quantization parameters, quantizing weight data, and inserting dequantization nodes. The deployment phase refers to the process of using the MindSpore framework to perform inference on the quantized model in the production environment. 
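Before walking through the full procedure below, it may help to see what "computing quantization parameters" and "quantizing weight data" mean numerically. The following NumPy sketch shows plain symmetric per-channel int8 weight quantization; it is only an illustration of the idea under simplified assumptions and is not the MindSpore Golden Stick implementation.

```python
import numpy as np

def quantize_weight_int8(w: np.ndarray):
    """Symmetric per-channel int8 quantization (illustrative only)."""
    # One scale per output channel: scale = max(|w|) / 127 (with a small floor to avoid division by zero).
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-8) / 127.0
    q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Corresponds to the dequantization nodes inserted after the int8 weights.
    return q.astype(np.float32) * scale

w = np.random.randn(4, 8).astype(np.float32)
q, scale = quantize_weight_int8(w)
print(np.abs(w - dequantize(q, scale)).max())  # maximum quantization error
```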
+ +MindSpore Golden Stick mainly uses `PTQConfig` to customize quantization and deployment, and uses the `apply` and `convert` APIs to implement quantization and deployment. You can configure whether to quantize the weight, activation, and KVCache, and configure the quantization bit in `PTQConfig`. In addition, you can configure the data calibration policy. For details, see [PTQConfig Description](#ptqconfig-description). + +The quantization procedure of MindSpore Golden Stick is as follows: + +```python +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, dtype +from mindformers.modules import Linear +from mindspore_gs.common import BackendTarget +from mindspore_gs.ptq import PTQMode, PTQConfig +from mindspore_gs.ptq.ptq import PTQ +from mindspore.dataset import GeneratorDataset + +class SimpleNet(nn.Cell): + class DecoderCell(nn.Cell): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def construct(self, *args, **kwargs): + return self.linear(*args, **kwargs) + + def __init__(self, foo_seq_length=1024): + super().__init__() + + self.foo_seq_length = foo_seq_length + linear = Linear(in_channels=foo_seq_length, out_channels=foo_seq_length, weight_init="ones") + self.decoder = SimpleNet.DecoderCell(linear) + + def construct(self, x): + return self.decoder(x) + + def generate(self, input_ids, do_sample=False, max_new_tokens=1): + input_ids = np.pad(input_ids, ((0, 0), (0, self.foo_seq_length - input_ids.shape[1])), 'constant', + constant_values=0) + return self.construct(Tensor(input_ids, dtype=dtype.float16)) + +def create_foo_ds(repeat=1): + class SimpleIterable: + def __init__(self, repeat=1): + self._index = 0 + self.data = [] + for _ in range(repeat): + self.data.append(np.array([[1, 1, 1]], dtype=np.int32)) + + def __next__(self): + if self._index >= len(self.data): + raise StopIteration + item = (self.data[self._index],) + self._index += 1 + return item + + def __iter__(self): + self._index = 0 + return self + + def __len__(self): + return len(self.data) + + return GeneratorDataset(source=SimpleIterable(repeat), column_names=["input_ids"]) + + +net = SimpleNet() # The float model that needs to be quantized +ds = create_foo_ds(1) +cfg = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND, weight_quant_dtype=dtype.int8) +ptq = PTQ(cfg) +ptq.apply(net, datasets=ds) +ptq.convert(net) + +ms.save_checkpoint(net.parameters_dict(), './simplenet_ptq.ckpt') +``` + +1. Use [nn.Cell](https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Cell.html) to define the network. After the model is trained, the floating-point weights of the model are obtained. During inference, the floating-point weights of the model are loaded. The preceding example simplifies the process by directly creating a network and quantizing the network using the initial floating-point weights. +2. Use PTQConfig to set the mode to quantization and backend to Ascend for 8-bit quantization of the weights. For details, see [PTQConfig Description](#ptqconfig-description). +3. Use the apply API to convert the network into a fake-quantized network and collect statistics on the quantization objects according to `PTQConfig`. +4. Use the convert API to perform real quantization on the fake-quantized network obtained in the previous step to obtain the quantized network. + +After the quantization is complete, you can use the quantized model for inference. 
The procedure is as follows: + +```python +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, dtype +from mindformers.modules import Linear +from mindspore_gs.common import BackendTarget +from mindspore_gs.ptq import PTQMode, PTQConfig +from mindspore_gs.ptq.ptq import PTQ +from mindspore.dataset import GeneratorDataset + +class SimpleNet(nn.Cell): + class DecoderCell(nn.Cell): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def construct(self, *args, **kwargs): + return self.linear(*args, **kwargs) + + def __init__(self, foo_seq_length=1024): + super().__init__() + + self.foo_seq_length = foo_seq_length + linear = Linear(in_channels=foo_seq_length, out_channels=foo_seq_length, weight_init="ones") + self.decoder = SimpleNet.DecoderCell(linear) + + def construct(self, x): + return self.decoder(x) + + def generate(self, input_ids, do_sample=False, max_new_tokens=1): + input_ids = np.pad(input_ids, ((0, 0), (0, self.foo_seq_length - input_ids.shape[1])), 'constant', + constant_values=0) + return self.construct(Tensor(input_ids, dtype=dtype.float16)) + +net = SimpleNet() +cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=dtype.int8) +ptq = PTQ(cfg) +ptq.apply(net) +ptq.convert(net) +ms.load_checkpoint('./simplenet_ptq.ckpt', net) + +input = Tensor(np.ones((5, 1024), dtype=np.float32), dtype=dtype.float32) +output = net(input) +print(output) +``` + +1. Use PTQConfig to set the mode to deployment and backend to Ascend for 8-bit quantization of the weights. For details, see [PTQConfig Description](#ptqconfig-description). +2. Use the apply and convert APIs to convert the network into a quantized network. In the deployment phase, no information statistics are collected or quantization computing is performed. Only the network structure is converted into a quantized network. +3. Load the quantized weights to the quantized network for inference. + +### PTQConfig Description + +You can customize the PTQConfig to enable different quantization capabilities. For details about PTQConfig, see the [API document](https://www.mindspore.cn/golden_stick/docs/en/master/ptq/mindspore_gs.ptq.PTQConfig.html#mindspore_gs.ptq.PTQConfig). The following lists the configuration examples of these algorithms: + +> **A** indicates activation, **W** indicates weight, **C** indicates KVCache, and the number indicates the bit. For example, A16W8 indicates that the activation is quantized to float16 and the weight is quantized to int8. + +- A16W8 weight quantization + + ```python + from mindspore import dtype as msdtype + from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType + + ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8, act_quant_dtype=None, kvcache_quant_dtype=None, + outliers_suppression=OutliersSuppressionType.NONE) + ``` + +- A8W8 quantization + + > A8W8 quantization is based on the [SmoothQuant](https://gitcode.com/gh_mirrors/smo/smoothquant/overview) algorithm. PTQConfig provides the **outliers_suppression** field to specify whether to perform the smooth operation. 
+ + ```python + from mindspore import dtype as msdtype + from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType + + ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8, kvcache_quant_dtype=None, + outliers_suppression=OutliersSuppressionType.SMOOTH) + ``` + +- KVCache int8 quantization + + ```python + from mindspore import dtype as msdtype + from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType + + ptq_config = PTQConfig(weight_quant_dtype=None, act_quant_dtype=None, kvcache_quant_dtype=msdtype.int8, + outliers_suppression=OutliersSuppressionType.NONE) + ``` + +## Examples + +### PTQ Examples + +The following provides the complete process of quantizing and deploying the post-training quantization (PTQ) algorithm on the Llama2 network: + +- [PTQ algorithm](https://www.mindspore.cn/golden_stick/docs/en/master/ptq/ptq.html): supports 8-bit weight quantization, 8-bit full quantization, and KVCacheInt8 quantization. SmoothQuant can be used to improve the quantization precision. Combined quantization algorithms of different algorithms are supported to improve the quantization inference performance. + +### Perceptual Quantization Training Examples + +- [SimQAT algorithm](https://www.mindspore.cn/golden_stick/docs/en/master/quantization/simulated_quantization.html): A basic quantization aware algorithm based on the fake quantization technology. +- [SLB quantization algorithm](https://www.mindspore.cn/golden_stick/docs/en/master/quantization/slb.html): A non-linear low-bit quantization aware algorithm. + +### Pruning Examples + +- [SCOP pruning algorithm](https://www.mindspore.cn/golden_stick/docs/en/master/pruner/scop.html): A structured weight pruning algorithm.