diff --git a/tutorials/source_zh_cn/model_infer/images/model_infer_stack.png b/tutorials/source_zh_cn/model_infer/images/model_infer_stack.png
index dbf3452a899187e994cf3bd658729fe71208ce75..3563a2765faddd2c0098e4ae5b9dafe34cb3616b 100644
Binary files a/tutorials/source_zh_cn/model_infer/images/model_infer_stack.png and b/tutorials/source_zh_cn/model_infer/images/model_infer_stack.png differ
diff --git a/tutorials/source_zh_cn/model_infer/introduction.md b/tutorials/source_zh_cn/model_infer/introduction.md
index 192448362c9d76f14e17055c9ba21656c9ee0427..5d765395acaa8567a3a4b55035cde0527533e100 100644
--- a/tutorials/source_zh_cn/model_infer/introduction.md
+++ b/tutorials/source_zh_cn/model_infer/introduction.md
@@ -74,4 +74,4 @@ The MindSpore framework provides multiple model inference approaches, so that users facing different
 
 - **Operator library**: MindSpore has a built-in operator library covering all kinds of computation; for inference scenarios it also includes various fused operators that improve inference performance.
 
-- **Lite inference**: targets resource-constrained devices, is built around a C++ runtime with a footprint of less than 1 MB, and is suitable for running on devices such as mobile phones. Lite inference also ships a matching serving solution and a Python API for different user scenarios.
+- **Lite inference**: targets resource-constrained devices, is built around a C++ runtime with a footprint of less than 1 MB, and is suitable for running on devices such as mobile phones. Lite inference also provides a Python API for different user scenarios.
diff --git a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst
index 802b1b8e0bfefb6ceaabf74123a3509b3a134084..650ba3c0ea08b1a33fdc035e83d23775d6722474 100644
--- a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst
+++ b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_model_infer.rst
@@ -262,7 +262,7 @@ MindSpore LLM framework-based inference mainly relies on the MindSpore open-source software; u
     from mindspore import ops, mint, Tensor, dtype
     from qwen2 import Qwen2Config, Qwen2ModelInput, Qwen2ForCausalLM, CacheManager, sample
 
-    def generate(model: Qwen2ForCausalLM, cache_manager: CacheManager, input_ids: List, max_new_tokens: int, max_seq_lens: int, eos_token_id: int):
+    def generate(model: Qwen2ForCausalLM, config: Qwen2Config, cache_manager: CacheManager, input_ids: List, max_new_tokens: int, max_seq_lens: int, eos_token_id: int):
         batch_size = len(input_ids)
         assert max_seq_lens >= max(map(len, input_ids))
 
@@ -355,6 +355,7 @@
 
     output = generate(
         model=model,
+        config=config,
         cache_manager=cache_manager,
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
diff --git a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_parallel_infer.md b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_parallel_infer.md
index 89492e3d3441b2d5920cebd9f2fc7a2fe4bd9938..f03053ed2976cf948b10d2fa7cfa44ae1619f6db 100644
--- a/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_parallel_infer.md
+++ b/tutorials/source_zh_cn/model_infer/ms_infer/ms_infer_parallel_infer.md
@@ -483,13 +483,15 @@ msrun --worker_num 2 --local_worker_num 2 --master_port 8124 --log_dir msrun_log
 
 ## Practice: Parallel Adaptation of the Qwen2 Model
 
-This chapter adapts the Qwen2 LLM built in [Building an LLM Inference Network from Scratch](./ms_infer_network_develop.md) for parallel execution. Based on the analysis above, the adaptation can be divided into the following two main steps:
+This chapter adapts the Qwen2 LLM built in [Building an LLM Inference Network from Scratch](./ms_infer_network_develop.md) for parallel execution. Based on the analysis above, the adaptation can be divided into the following three main steps:
 
 1. **Model network adaptation**: following the parallel scheme above, partition the network layers of the model so that the computation is distributed across multiple devices.
 
 2. **Model weight adaptation**: because the shape of the Linear weights changes after partitioning, the weight loading code has to be adjusted accordingly.
 
-To keep the scenario simple, this chapter only splits the Linear layers of the Qwen2 model with a parallelism degree of 2; splitting the Embedding layer is not covered for now.
+3. **KVCache adaptation**: because the number of attention heads involved in computing the attention scores is also divided by the parallelism degree, the shapes managed by the KVCache have to be updated as well.
+
+To keep the scenario simple, this chapter only splits the Linear layers of the Qwen2 model with a parallelism degree of 2; splitting the Embedding layer is not covered for now. It is recommended to rename the single-card infer.py and qwen2.py files from the example to infer_parallel.py and qwen2_parallel.py so that the two versions of the code do not conflict.
 
 ### Setting Up the Communication Group
 
@@ -519,7 +521,7 @@ class CommunicationHelper:
 COMMON_HELPER = None
 
 def init_communication():
-    TP+GROUP_NAME = "tp"
+    TP_GROUP_NAME = "tp"
     TP_SIZE = 2
 
     global COMMON_HELPER
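The hunk above fixes the misspelled group-name variable, but `CommunicationHelper` itself and the rest of `init_communication` lie outside the diff context. For reference, a minimal sketch of such a helper built on the standard `mindspore.communication` interfaces (`init`, `create_group`, `get_group_size`, `get_rank`) could look like this; it illustrates the pattern and is not necessarily the exact code in the tutorial:

```python
from mindspore.communication import init, create_group, get_group_size, get_rank


class CommunicationHelper:
    """Wraps the tensor-parallel communication group used by the parallel layers."""

    def __init__(self, group_name: str, size: int) -> None:
        self.group_name = group_name
        self.size = size
        self.rank_list = [i for i in range(size)]

    def create_tensor_model_parallel_group(self):
        # Register the TP group; with a world size equal to TP_SIZE this simply
        # groups all ranks together.
        create_group(group=self.group_name, rank_ids=self.rank_list)

    def get_tensor_model_parallel_group(self):
        return self.group_name

    def get_tensor_model_parallel_group_size(self):
        return get_group_size(group=self.group_name)

    def get_tensor_model_parallel_group_rank(self):
        return get_rank(group=self.group_name)


COMMON_HELPER = None


def init_communication():
    TP_GROUP_NAME = "tp"
    TP_SIZE = 2

    global COMMON_HELPER

    init()  # initialize the distributed backend before creating sub-groups
    COMMON_HELPER = CommunicationHelper(group_name=TP_GROUP_NAME, size=TP_SIZE)
    COMMON_HELPER.create_tensor_model_parallel_group()
```

With this in place, the parallel layers below only need the group name and group size exposed by `COMMON_HELPER`.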
@@ -532,21 +534,21 @@ def init_communication():
 
 This scheme parallelizes the Linear layers, so that is where the main changes go: the original Qwen2Linear is replaced by two classes, Qwen2ColParallelLinear and Qwen2RowParallelLinear, corresponding to the column-wise and row-wise splits respectively. The code can be found below:
 
-```diff
+```python
 from typing import Optional, Type, Tuple
 from mindspore import nn, ops, mint, Parameter, Tensor
 
+
 class Qwen2ColParallelLinear(nn.Cell):
-    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None:
+    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None:
         super().__init__()
-+        self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
+        self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
         self.param_dtype = param_dtype
         self.input_size = input_size
--        self.output_size = output_size
-+        self.output_size = output_size // self.tp_size
-        self.enable_bias = bias
+        self.output_size = output_size // self.tp_size
+        self.enable_bias = enable_bias
 
         self.matmul = ops.MatMul(transpose_b=True)
         self.weight = Parameter(
@@ -571,22 +573,22 @@ class Qwen2ColParallelLinear(nn.Cell):
 
 class Qwen2RowParallelLinear(nn.Cell):
-    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None:
+    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None:
         super().__init__()
-+        self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
+        self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
         self.param_dtype = param_dtype
--        self.input_size = input_size
-+        self.input_size = input_size // self.tp_size
+        self.input_size = input_size // self.tp_size
         self.output_size = output_size
-        self.enable_bias = bias
+        self.enable_bias = enable_bias
 
         self.matmul = ops.MatMul(transpose_b=True)
         self.weight = Parameter(
             mint.zeros(
                 (self.output_size, self.input_size),
                 dtype=self.param_dtype
-            ), requires_grad=False
+            ),
+            requires_grad=False
         )
 
         if self.enable_bias:
@@ -594,14 +596,14 @@ class Qwen2RowParallelLinear(nn.Cell):
             self.bias = Parameter(
                 mint.zeros(self.output_size, dtype=self.param_dtype)
             )
-+        self.all_reduce = ops.AllReduce(group=COMMON_HELPER.get_tensor_model_parallel_group())
+        self.all_reduce = ops.AllReduce(group=COMMON_HELPER.get_tensor_model_parallel_group())
 
-    def construct(self, input: Tensor) -> Tuple[Tensor, bool]:
+    def construct(self, input: Tensor):
         origin_shape = input.shape
         x = self.matmul(input.view(-1, origin_shape[-1]), self.weight)
         if self.enable_bias:
             x = self.bias_add(x, self.bias)
-+        x = self.all_reduce(x)
+        x = self.all_reduce(x)
         return x.view(*origin_shape[:-1], -1)
 ```
 
@@ -646,25 +648,29 @@ class Qwen2Attention(nn.Cell):
+        self.paged_attn = PagedAttention(self.num_heads // self.tp_size, self.scaling, self.num_kv_heads // self.tp_size)
         self.reshape_and_cache = ops.auto_generate.ReshapeAndCache()
 
-        self.q_proj = Qwen2ColParallelLinear(
+-        self.q_proj = Qwen2Linear(
++        self.q_proj = Qwen2ColParallelLinear(
             input_size=self.hidden_size,
             output_size=self.q_size,
             param_dtype=self.param_dtype,
             bias=True
         )
-        self.k_proj = Qwen2ColParallelLinear(
+-        self.k_proj = Qwen2Linear(
++        self.k_proj = Qwen2ColParallelLinear(
             input_size=self.hidden_size,
             output_size=self.kv_size,
             param_dtype=self.param_dtype,
             bias=True
         )
-        self.v_proj = Qwen2ColParallelLinear(
+-        self.v_proj = Qwen2Linear(
++        self.v_proj = Qwen2ColParallelLinear(
             input_size=self.hidden_size,
             output_size=self.kv_size,
             param_dtype=self.param_dtype,
             bias=True
         )
-        self.o_proj = Qwen2RowParallelLinear(
+-        self.o_proj = Qwen2Linear(
++        self.o_proj = Qwen2RowParallelLinear(
             input_size=self.q_size,
             output_size=self.hidden_size,
             param_dtype=self.param_dtype,
@@ -685,13 +691,21 @@ class Qwen2Attention(nn.Cell):
                   q_seq_lens: Tensor) -> Tensor:
         bs, seq_len, hidden_dim = hidden_state.shape
 
--        q = self.q_proj(hidden_state).view(-1, self.q_size // self.tp_size)
--        k = self.k_proj(hidden_state).view(-1, self.kv_size // self.tp_size)
--        v = self.v_proj(hidden_state).view(-1, self.kv_size // self.tp_size)
+-        q = self.q_proj(hidden_state).view(-1, self.q_size)
+-        k = self.k_proj(hidden_state).view(-1, self.kv_size)
+-        v = self.v_proj(hidden_state).view(-1, self.kv_size)
+        q = self.q_proj(hidden_state).view(-1, self.q_size // self.tp_size)
+        k = self.k_proj(hidden_state).view(-1, self.kv_size // self.tp_size)
+        v = self.v_proj(hidden_state).view(-1, self.kv_size // self.tp_size)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            batch_valid_length,
+            is_prefill
+        )
+
         k = k.contiguous()
         v = v.contiguous()
@@ -730,19 +744,22 @@ class Qwen2MLP(nn.Cell):
     def __init__(self, config: Qwen2Config) -> None:
         super().__init__()
 
-        self.up_proj = Qwen2ColParallelLinear(
+-        self.up_proj = Qwen2Linear(
++        self.up_proj = Qwen2ColParallelLinear(
             input_size=config.hidden_size,
             output_size=config.intermediate_size,
             param_dtype=config.param_dtype,
             bias=False
         )
-        self.gate_proj = Qwen2ColParallelLinear(
+-        self.gate_proj = Qwen2Linear(
++        self.gate_proj = Qwen2ColParallelLinear(
             input_size=config.hidden_size,
             output_size=config.intermediate_size,
             param_dtype=config.param_dtype,
             bias=False
         )
-        self.down_proj = Qwen2RowParallelLinear(
+-        self.down_proj = Qwen2Linear(
++        self.down_proj = Qwen2RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             param_dtype=config.param_dtype,
@@ -756,6 +773,8 @@ class Qwen2MLP(nn.Cell):
 
 +class GatherLastDim(nn.Cell):
 +    def __init__(self):
++        super().__init__()
++
 +        self.all_gather = ops.AllGather(group=COMMON_HELPER.get_tensor_model_parallel_group())
 +        self.world_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
 +        self.split = ops.Split(axis=0, output_num=self.world_size)
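The hunk above adds the missing `super().__init__()` call to `GatherLastDim`, whose `construct` method falls outside the diff context. Because `ops.AllGather` concatenates the per-rank tensors along axis 0, a plausible `construct` (a sketch of the pattern, assuming the `COMMON_HELPER` created by `init_communication`, not necessarily the file's exact code) splits the gathered result back apart and re-concatenates it along the last axis:

```python
from mindspore import nn, ops, Tensor


class GatherLastDim(nn.Cell):
    """Collect the column-parallel outputs of all TP ranks along the last axis."""

    def __init__(self):
        super().__init__()
        self.all_gather = ops.AllGather(group=COMMON_HELPER.get_tensor_model_parallel_group())
        self.world_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
        self.split = ops.Split(axis=0, output_num=self.world_size)

    def construct(self, input: Tensor) -> Tensor:
        # AllGather stacks the per-rank tensors along axis 0 ...
        output = self.all_gather(input)
        # ... so split them apart again and concatenate along the hidden axis.
        tensor_list = self.split(output)
        return ops.cat(tensor_list, axis=-1)
```

This is what lets the column-parallel `lm_head` below still produce full vocabulary-sized logits on every rank.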
@@ -771,7 +790,8 @@ class Qwen2ForCausalLM(nn.Cell):
         super().__init__()
 
         self.model = Qwen2Model(config=config)
-        self.lm_head = Qwen2ColParallelLinear(
+-        self.lm_head = Qwen2Linear(
++        self.lm_head = Qwen2ColParallelLinear(
             input_size=config.hidden_size,
             output_size=config.vocab_size,
             param_dtype=config.param_dtype,
@@ -810,14 +830,14 @@ from typing import Optional, Type, Tuple
 from mindspore import nn, ops, mint, Parameter, Tensor
 
 class Qwen2ColParallelLinear(nn.Cell):
-    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None:
+    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None:
         super().__init__()
         self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
         self.param_dtype = param_dtype
         self.input_size = input_size
         self.output_size = output_size // self.tp_size
-        self.enable_bias = bias
+        self.enable_bias = enable_bias
 
         self.matmul = ops.MatMul(transpose_b=True)
         self.weight = Parameter(
@@ -855,14 +875,14 @@ class Qwen2ColParallelLinear(nn.Cell):
 
 class Qwen2RowParallelLinear(nn.Cell):
-    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], bias: bool) -> None:
+    def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None:
         super().__init__()
         self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
         self.param_dtype = param_dtype
         self.input_size = input_size // self.tp_size
         self.output_size = output_size
-        self.enable_bias = bias
+        self.enable_bias = enable_bias
 
         self.matmul = ops.MatMul(transpose_b=True)
         self.weight = Parameter(
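The two hunks above only touch the constructors; the custom weight loading described after the next hunk shards the full checkpoint tensor with `Tensor.narrow`, and the column- and row-parallel layers differ only in the axis they narrow. A small sketch of that sharding (the helper names are hypothetical and given only for illustration):

```python
from mindspore import Tensor


def shard_col_parallel(full_weight: Tensor, tp_rank: int, out_per_rank: int) -> Tensor:
    # Column-parallel Linear: the full weight is (output_size * tp_size, input_size),
    # so each rank keeps its own block of output rows.
    return full_weight.narrow(0, tp_rank * out_per_rank, out_per_rank)


def shard_row_parallel(full_weight: Tensor, tp_rank: int, in_per_rank: int) -> Tensor:
    # Row-parallel Linear: the full weight is (output_size, input_size * tp_size),
    # so each rank keeps its own block of input columns.
    return full_weight.narrow(1, tp_rank * in_per_rank, in_per_rank)
```

A `weight_load` method would then call the matching helper with `COMMON_HELPER.get_tensor_model_parallel_group_rank()` and hand the resulting slice to `Parameter.set_data`.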
@@ -932,8 +952,56 @@ class Qwen2ForCausalLM(nn.Cell):
 
 The code above adds a weight_load method to the network layers that need custom weight loading and attaches it to their weight objects via setattr. When the model weights are loaded, the weight mapping table is read to find the corresponding parameter object and update its data. Both the column-parallel and row-parallel Linear layers use Tensor's narrow to fetch the data at the matching offset; the only difference between them is the axis along which they are split.
 
+### Splitting the KVCache
+
+Splitting the KVCache is straightforward as long as num_key_value_heads is divisible by the parallelism degree: the corresponding shapes simply need to be adjusted, as shown in the following code:
+
+```diff
+class CacheManager:
+    def __init__(self, config: Qwen2Config, block_num: int, block_size: int, batch_size: int) -> None:
+        self.block_num = block_num
+        self.block_size = block_size
+        self.batch_size = batch_size
+
+        head_dim = config.hidden_size // config.num_attention_heads
+
++        self.tp_size = COMMON_HELPER.get_tensor_model_parallel_group_size()
+-        self.k_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
+-        self.v_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
++        self.k_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads // self.tp_size, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
++        self.v_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads // self.tp_size, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
+        self.block_tables = [[] for _ in range(batch_size)]
+        self.acc_slot_mapping = [[] for _ in range(batch_size)]
+        self.free_block_ids = deque(range(block_num))
+
+    def step(self, start_pos_idx: int, token_num_per_batch: int) -> Tuple[Tensor, Tensor]:
+```
+
+As the code shows, only a small adjustment to the shapes used when the KVCache is initialized is needed to complete its parallel adaptation.
+
 ### Parallel Execution
 
+Because parallel execution needs the communication domain to be initialized, the init_communication function also has to be called during the initialization phase of infer_parallel.py; it is recommended to call it right after set_context, as shown in the following code:
+
+```diff
+
+ import os
+ import mindspore as ms
+- from qwen2_parallel import Qwen2Config, Qwen2ForCausalLM, CacheManager
++ from qwen2_parallel import Qwen2Config, Qwen2ForCausalLM, CacheManager, init_communication
+ from mindspore import Tensor, mint
+
+ # set mindspore context and envs
+ os.environ["MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST"] = "PagedAttention"
+
+ ms.set_context(infer_boost="on")
+ ms.set_context(mode=ms.context.PYNATIVE_MODE)
+
++ init_communication()
+
+ model_path = "/path/to/model"
+```
+
 After the model and weight adaptation is complete, multi-card execution can be launched with the following command:
 
 ```shell