diff --git a/tutorials/experts/source_zh_cn/parallel/sharding_propagation.md b/tutorials/experts/source_zh_cn/parallel/sharding_propagation.md index fb2850c7ad1436a0021e4892ce44a5784952186f..df30cf43cdcf2bd465991ef504cbae98fb188d78 100644 --- a/tutorials/experts/source_zh_cn/parallel/sharding_propagation.md +++ b/tutorials/experts/source_zh_cn/parallel/sharding_propagation.md @@ -4,7 +4,7 @@ ## 概述 -切分策略传播算法是指,用户仅需要手工定义几个关键算子的策略,计算图中其余算子的策略由算法自动生成。因为关键算子的策略已被定义,该算法的代价模型主要描述的算子之间的重排布代价(Redistribution Cost),优化目标为全图重排代价最小。因为已经定义了主要算子策略,相当于认为压缩了搜索空间,这种方案的搜索时间较短,其策略性能依赖于关键算子策略的定义,因此仍然要求用户具备一定的分析定义策略的能力。 +切分策略传播算法是指,用户仅需要手工定义几个关键算子的策略,计算图中其余算子的策略由算法自动生成。因为关键算子的策略已被定义,该算法的代价模型主要描述的算子之间的重排布代价(Redistribution Cost),优化目标为全图重排与计算代价最小。因为已经定义了主要算子策略,相当于认为压缩了搜索空间,这种方案的搜索时间较短,其策略性能依赖于关键算子策略的定义,因此仍然要求用户具备一定的分析定义策略的能力。 > 切分策略传播算法支持的硬件平台包括Ascend、GPU,此外还同时支持PyNative模式和Graph模式。 @@ -12,21 +12,21 @@ 1. `mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation")`:设置并行模式为自动并行,且搜索模式为切分策略传播算法。 -2. `mindspore.nn.Cell.shard()`以及`mindspore.ops.Primitive.shard()`:指定算子切分策略,其余算子的策略通过传播算法推导得到。目前`mindspore.nn.Cell.shard()`接口仅可在PyNative模式下使用;`mindspore.ops.Primitive.shard()`接口仅可在图模式下使用。 +2. `mindspore.nn.Cell.shard()`以及`mindspore.ops.Primitive.shard()`:指定算子切分策略,其余算子的策略通过传播算法推导得到。目前`mindspore.nn.Cell.shard()`接口同时支持PyNative模式与图模式;`mindspore.ops.Primitive.shard()`接口仅可在图模式下使用。 总而言之,切分策略传播算法需要用户手动配置关键算子的切分策略。 ## 基本原理 -给定计算图,**切分策略传播**(Sharding Propagation)使得策略由配置的算子传播到整张计算图的所有算子。在传播过程中,策略选择的目标是最小化张量重排布产生的通信代价。 +给定计算图,**切分策略传播**(Sharding Propagation)使得策略由配置的算子传播到整张计算图的所有算子。在传播过程中,策略选择的目标是最小化张量重排布产生的通信代价与计算代价。 切分策略传播的输入是带有一些算子切分策略的计算图,其中的点表示算子,有向边表示数据依赖关系。切分策略传播的执行流程如下: 1. 为未配置切分策略的算子生成可行的切分策略; 2. 为每条边生成重排布策略及相应的代价; -3. 从已配置切分策略的算子出发,利用广度优先搜索将切分策略传播到其它未配置切分策略的算子。传播过程的目标是最小化每条边上的重排布通信代价。若在当前传播路径上遇到已配置策略的算子,则停止传播。 +3. 从已配置切分策略的算子出发,利用广度优先搜索将切分策略传播到其它未配置切分策略的算子。传播过程的目标是最小化每条边上的重排布通信与计算代价。若在当前传播路径上遇到已配置策略的算子,则停止传播。 -下图所示的是切分策略传播的一个流程实例。在给定带有切分策略的计算图后,首先为没有配置切分策略的算子枚举其可行的策略,如下图(b)所示。然后,为每条边枚举重排布策略和相应的代价。如下图(c)所示,这里的重排布策略定义为二元组[*s_strategy*, *t_strategy*],其中*s_strategy*表示的是源算子(下图(c)中的ReLU)的切分策略,*t_strategy*表示的是目的算子(下图(c)中的MatMul)的切分策略。当沿着一条边传播到下一个算子时(如图中ReLU切分策略已确定,为[2, 4],下一步要决定MatMul算子的切分策略),总是在表中选择引起通信代价最小的目的算子的策略(即为MatMul选择[[2, 4], [4, 1]])。最后,所有算子的切分策略都被确定,如下图(d)所示。 +下图所示的是切分策略传播的一个流程实例。在给定带有切分策略的计算图后,首先为没有配置切分策略的算子枚举其可行的策略,如下图(b)所示。然后,为每条边枚举重排布策略和相应的代价。如下图(c)所示,这里的重排布策略定义为二元组[*s_strategy*, *t_strategy*],其中*s_strategy*表示的是源算子(下图(c)中的ReLU)的切分策略,*t_strategy*表示的是目的算子(下图(c)中的MatMul)的切分策略。当沿着一条边传播到下一个算子时(如图中ReLU切分策略已确定,为[2, 4],下一步要决定MatMul算子的切分策略),总是在表中选择产生代价最小的目的算子的策略(即为MatMul选择[[2, 4], [4, 1]])。最后,所有算子的切分策略都被确定,如下图(d)所示。 ![sharding_propagation](./images/sharding_propagation_zh.png "切分策略传播的流程实例") diff --git a/tutorials/experts/source_zh_cn/parallel/split_technique.md b/tutorials/experts/source_zh_cn/parallel/split_technique.md index 3ace1eefb918e7c89038289f13acfa5eb9d66d37..bda7bafd7abba28beff6abae67ab165981d3398c 100644 --- a/tutorials/experts/source_zh_cn/parallel/split_technique.md +++ b/tutorials/experts/source_zh_cn/parallel/split_technique.md @@ -4,24 +4,192 @@ ## 概述 -给定一个新模型,从用户的角度,关键问题是配置哪些算子的切分策略来获得较好的性能。由于策略传播的目标是最小化张量重排布的代价,而非最小化端到端的迭代时间,因此,为“关键算子”配置合适的切分策略是十分重要的。然而,并不存在明确的规定约束哪些算子是必须配置切分策略的。尽管如此,基于我们训练大模型的经验,确实有一些原则可以用来指导新用户配置并行策略。这里,我们列出3条经验性的原则。 +对于一个新模型使用sharding propagation来配置并行策略,关键问题在于配置哪些算子的切分策略来获得较好的性能。由于策略传播的目标是最小化张量重排布的代价,而非最小化端到端的迭代时间,因此,为“关键算子”配置合适的切分策略是十分重要的。然而,并不存在明确的规定约束哪些算子是必须配置切分策略的。尽管如此,基于我们训练大模型的经验,确实有一些原则可以用来指导新用户配置并行策略。这里,我们列出4条经验性的原则。 -### 配置涉及权重的算子 +#### 配置涉及权重的算子 参数权重的切分策略是十分重要的,尤其对大模型来说,因为参数权重引起的内存消耗占据模型训练总内存消耗的大部分。因此,涉及权重的算子通常需要显式地配置切分策略。在下图的两个例子中,涉及权重的Gather和MatMul算子配置了切分策略,而其他算子没有配置。这分别对应[mindformers](https://gitee.com/mindspore/mindformers/blob/master/mindformers/modules/transformer/transformer.py)中的数据并行VocabEmbedding层和混合并行FeedForward层。 ![sp_case1_zh](./images/sp_case1_zh.png "配置涉及权重的算子") +#### 配置维度改变/轴改变的算子 -### 配置维度改变的算子 - -深度学习框架的算子大致可以分为两类:语义简单的维度保持的算子;会改变输入张量维度的算子。对于维度保持算子,策略传播算法可以较容易地将切分策略传播出去。但是,对于维度改变算子,显式地配置切分策略才能更好地表达用户的初始想法,避免策略传播算法推导出非用户期望的切分策略。在下图的例子中,ReduceMean和MatMul是维度改变算子,它们被配置了切分策略。 +深度学习框架的算子大致可以分为两类:语义简单的维度保持的算子;会改变输入张量维度的算子。对于维度保持算子,策略传播算法可以较容易地将切分策略传播出去。但是,对于维度改变算子,显式地配置切分策略才能更好地表达用户的初始想法,避免策略传播算法推导出非用户期望的切分策略。常见的维度改变/轴改变算子有:[ReduceMean](https://www.mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.ReduceMean.html)、[ReduceSum](https://www.mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.ReduceSum.html)、[Transpose](https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/ops/mindspore.ops.transpose.html)、[StridedSlice](https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/ops/mindspore.ops.StridedSlice.html)以及[Matmul](https://www.mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.MatMul.html)与[BatchMatmul](https://www.mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.BatchMatMul.html)。在下图的例子中,ReduceMean和MatMul是维度改变算子,它们被配置了切分策略。 ![sp_case2_zh](./images/sp_case2_zh.png "配置维度改变的算子") -### 配置并行方式改变的边界算子 +#### 配置并行策略改变的边界算子 -对于类似ResNet类的模型,模型的不同部分偏好的并行方式不同:前半部分使用数据并行,后半部分使用模型并行,以此获得最优的迭代性能。这可以通过为并行方式改变的边界算子配置策略来实现。在下图的例子中,第一个MatMul配置了数据并行的策略,它会将数据并行的策略向前传播到模型的前半部分;第二个MatMul配置了模型并行的策略,它会将模型并行的策略向后传播到模型的后半部分。 +对于类似ResNet类的模型,模型的不同部分偏好的并行方式不同:前半部分使用数据并行,后半部分使用模型并行,以此获得最优的迭代性能。对于LLAMA类大模型,当vocab_size过大时,出于对内存的考虑,可能会选择模型并行切分;当sequence_length过大时,也可能会选择序列并行的策略。以上策略属于用户根据模型和硬件信息精心配置的策略。sharding propagation是一种朴素的寻找重排代价最低的算法,并不能自动找到精心配置的策略,因此,对于用户精心调优的算子策略,需要对其进行专门的配置。在下图的例子中,第一个MatMul配置了数据并行的策略,它会将数据并行的策略向前传播到模型的前半部分;第二个MatMul配置了模型并行的策略,它会将模型并行的策略向后传播到模型的后半部分。 ![sp_case3_zh](./images/sp_case3_zh.png "配置并行方式改变的边界算子") + +#### 配置融合算子 + +对于融合大算子,如F[lashAttentionScore](https://www.mindspore.cn/lite/api/zh-CN/master/generate/classmindspore_ops_FlashAttentionScore.html#exhale-class-classmindspore-ops-flashattentionscore)、[RmsNorm](https://www.mindspore.cn/lite/api/zh-CN/master/generate/classmindspore_ops_RmsNorm.html),也是需要用户手动配置策略的算子,融合算子的输入与输出逻辑相对复杂,对于策略传播算法很难传播出最优的策略,这些算子也需要显式配置算子级策略。 + 用户在用策略传播时不仅需要对其传播算法本身有一定的了解,还要对要训练的模型的并行方式有一定的理解。如果存在某个由策略传播算法决定的算子的并行策略不符合用户的期望,那总可以通过多配置一个算子并行策略的方式解决。实际中,对于一个新模型,确实需要尝试几次才能获得性能较优的整体并行配置。 + +## 配置代码样例 +以Mindformers中封装的类[RowParallelLinear](https://gitee.com/mindspore/mindformers/blob/r1.2.0/mindformers/experimental/distri_cores/tensor_parallel/layers.py)为例: + + + + + + +
+ +```diff +# 如果使用半自动,需要调用shard方法配置所有算子的策略 +class RowParallelLinear(nn.Cell): + def shard(self, config: TransformerConfig) -> None: + dp = config.data_parallel + tp = config.tensor_parallel + cp = config.context_parallel + if self.transpose_b: + weight_strategy = (tp, 1) + else: + weight_strategy = (1, tp) + matmul_in_strategy = ((dp * cp, 1), weight_strategy) + self.matmul.shard(in_strategy=matmul_in_strategy) +- if not self.skip_bias_add: +- dd_in_strategy = ((dp * cp, tp), (tp,)) +- self.add.shard(in_strategy=add_in_strategy) +``` + + + + +```diff +# 而使用策略传播,只需要配置其中MatMul算子的策略,无需配置Add算子: +class RowParallelLinear(nn.Cell): + def shard(self, config: TransformerConfig) -> None: + dp = config.data_parallel + tp = config.tensor_parallel + cp = config.context_parallel + if self.transpose_b: + weight_strategy = (tp, 1) + else: + weight_strategy = (1, tp) + matmul_in_strategy = ((dp * cp, 1), weight_strategy) + self.matmul.shard(in_strategy=matmul_in_strategy) +``` + +
+ + +另一个例子是[CoreAttention](https://gitee.com/mindspore/mindformers/blob/r1.2.0/mindformers/experimental/distri_cores/transformer/transformer.py),根据上述原则配置: + + + + + +
+ +```diff +# 半自动配置 +class CoreAttention(nn.Cell): + def shard(self, config: TransformerConfig): + dp = config.data_parallel + tp = config.tensor_parallel + cp = config.context_parallel +- dropout_strategy = (dp, tp, cp, 1) +- self.dropout.shard(strategy=dropout_strategy) + self.bmm_qkv.shard(((dp, tp, cp, 1), (dp, tp, 1, 1))) +- self.mul.shard(((dp, tp, cp, 1), ())) + self.bmm_qk.shard(((dp, tp, cp, 1), (dp, tp, 1, 1))) + self.merge_head_transpose.shard(((dp, tp, cp, 1),)) +``` + + + +```diff +# 策略传播配置代码如下,仅需配置Matmul与Transpose算子: +class CoreAttention(nn.Cell): + def shard(self, config: TransformerConfig): + dp = config.data_parallel + tp = config.tensor_parallel + cp = config.context_parallel + self.bmm_qkv.shard(((dp, tp, cp, 1), (dp, tp, 1, 1))) + self.bmm_qk.shard(((dp, tp, cp, 1), (dp, tp, 1, 1))) + self.merge_head_transpose.shard(((dp, tp, cp, 1),)) +``` + +
+ +再看[FlashAttention](https://gitee.com/mindspore/mindformers/blob/dev/mindformers/modules/flash_attention.py)的例子: + + + + + +
+ +```diff +# 半自动配置如下 +class FlashAttention(Cell): + def shard(self, parallel_config): + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + cp = parallel_config.context_parallel + cp_ds = parallel_config.get_ulysses_cp_num() + fa_strategies = self._generate_flash_attention_strategy(dp, mp, cp, cp_ds) + self.flash_attention.shard(fa_strategies) +- if self.use_alibi_mask: +- self.alibi_rescale_mul.shard(((dp, mp, cp, 1), (1,))) +- return self +``` + + + +```diff +# 策略传播配置代码如下,需要配置FlashAttentionScore算子,无需配置Mul算子: +class FlashAttention(Cell): + def shard(self, parallel_config): + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + cp = parallel_config.context_parallel + cp_ds = parallel_config.get_ulysses_cp_num() + fa_strategies = self._generate_flash_attention_strategy(dp, mp, cp, cp_ds) + self.flash_attention.shard(fa_strategies) + return self +``` + +
+ + +若直接使用MindFormers中开源且已经配好策略的类,则外部网络无需对算子再配置shard策略,如[LlamaForCausalLM](https://gitee.com/mindspore/mindformers/blob/dev/mindformers/models/llama/llama.py)。 + + + + + +
+ +```diff +# 使用半自动配置 +class LlamaForCausalLM(LlamaPretrainedModel): + def shard(self, config: TransformerConfig): +- dp = config.data_parallel +- slice_in_strategy = ((dp, 1),) +- self.slice.shard(in_strategy=slice_in_strategy) +- not_equal_in_strategy = ((dp, 1), ()) +- self.not_equal.shard(in_strategy=not_equal_in_strategy) +- mul_in_strategy = ((dp, 1), (dp, 1)) +- self.mul.shard(in_strategy=mul_in_strategy) +- return self +``` + + + +```diff +# 使用策略传播无需配置其他算子策略 +class LlamaForCausalLM(LlamaPretrainedModel): + def shard(self, config: TransformerConfig): ++ pass +``` + +
+ +**用户无法确认是否需要对算子配置策略时,可以不配置,由算法传播找寻最优策略,但是可能无法获得最佳的并行效果;若用户能够确认该算子需要配置什么策略,则可以进行配置帮助算法获得预期效果。** \ No newline at end of file